From c32bf3c9dbb7b16d7b0faeda72c1cca10445fb56 Mon Sep 17 00:00:00 2001 From: Thomas Coratger <60488569+tcoratger@users.noreply.github.com> Date: Wed, 27 May 2026 14:45:30 +0200 Subject: [PATCH 1/9] refactor(xmss): restructure subspec into focused modules and clean APIs Reorganize the XMSS subspec around single-responsibility files and tighten the public APIs while preserving all cryptographic behavior. Structure: - Split low-level primitives into field.py (base-P decomposition + secure sampling), poseidon.py (Poseidon1 engine + tweakable hash), and prf.py (SHAKE128 PRF); hashing.py is removed. - tweak_hash and hash_chain become methods on PoseidonXmss; the tweak NamedTuples move to types.py. - Merge message_hash.py + target_sum.py into encoding.py, subtree.py into merkle.py, and delete the rand.py/utils.py grab-bags. APIs: - TypeOneMultiSignature.aggregate becomes a classmethod taking (validator_index, pubkey, signature) tuples, dropping the optional participant bitfield and its runtime guards. - Move the greedy set-cover selection out of the multisig type into forks/lstar/aggregation_select.py. - Replace inline noqa for uppercase config properties with a file-scoped ruff ignore. Docs: - Tighten docstrings to the project /doc rules: one sentence per line, no backticks, WHY-focused comments, worked examples for the encoding. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../testing/src/consensus_testing/keys.py | 6 +- .../test_types/block_spec.py | 4 +- pyproject.toml | 4 + .../forks/lstar/aggregation_select.py | 51 ++ src/lean_spec/forks/lstar/spec.py | 24 +- src/lean_spec/subspecs/sync/service.py | 7 +- src/lean_spec/subspecs/validator/service.py | 8 +- src/lean_spec/subspecs/xmss/__init__.py | 10 +- src/lean_spec/subspecs/xmss/aggregation.py | 356 +++++----- src/lean_spec/subspecs/xmss/constants.py | 88 +-- src/lean_spec/subspecs/xmss/containers.py | 80 ++- src/lean_spec/subspecs/xmss/encoding.py | 177 +++++ src/lean_spec/subspecs/xmss/field.py | 32 + src/lean_spec/subspecs/xmss/interface.py | 516 +++++---------- src/lean_spec/subspecs/xmss/merkle.py | 480 ++++++++++++++ src/lean_spec/subspecs/xmss/message_hash.py | 161 ----- src/lean_spec/subspecs/xmss/poseidon.py | 297 +++++---- src/lean_spec/subspecs/xmss/prf.py | 275 +++----- src/lean_spec/subspecs/xmss/rand.py | 36 - src/lean_spec/subspecs/xmss/subtree.py | 615 ------------------ src/lean_spec/subspecs/xmss/target_sum.py | 89 --- src/lean_spec/subspecs/xmss/tweak_hash.py | 253 ------- src/lean_spec/subspecs/xmss/types.py | 138 ++-- src/lean_spec/subspecs/xmss/utils.py | 155 ----- .../forkchoice/test_attestation_target.py | 4 +- .../forkchoice/test_store_attestations.py | 82 +-- tests/lean_spec/helpers/builders.py | 18 +- .../subspecs/validator/test_service.py | 6 +- .../subspecs/xmss/test_aggregation.py | 132 +--- .../lean_spec/subspecs/xmss/test_interface.py | 11 +- .../subspecs/xmss/test_merkle_tree.py | 71 +- .../subspecs/xmss/test_message_hash.py | 53 +- tests/lean_spec/subspecs/xmss/test_prf.py | 23 +- .../subspecs/xmss/test_security_levels.py | 16 +- tests/lean_spec/subspecs/xmss/test_utils.py | 43 +- 35 files changed, 1693 insertions(+), 2628 deletions(-) create mode 100644 src/lean_spec/forks/lstar/aggregation_select.py create mode 100644 src/lean_spec/subspecs/xmss/encoding.py create mode 100644 src/lean_spec/subspecs/xmss/field.py create mode 100644 src/lean_spec/subspecs/xmss/merkle.py delete mode 100644 src/lean_spec/subspecs/xmss/message_hash.py delete mode 100644 src/lean_spec/subspecs/xmss/rand.py delete mode 100644 src/lean_spec/subspecs/xmss/subtree.py delete mode 100644 src/lean_spec/subspecs/xmss/target_sum.py delete mode 100644 src/lean_spec/subspecs/xmss/tweak_hash.py delete mode 100644 src/lean_spec/subspecs/xmss/utils.py diff --git a/packages/testing/src/consensus_testing/keys.py b/packages/testing/src/consensus_testing/keys.py index 2832a5e81..1998f6ba8 100755 --- a/packages/testing/src/consensus_testing/keys.py +++ b/packages/testing/src/consensus_testing/keys.py @@ -69,7 +69,6 @@ Slot, Uint64, ValidatorIndex, - ValidatorIndices, ) KeyRole = Literal["attestation", "proposal"] @@ -534,13 +533,13 @@ def sign_and_aggregate( """ raw_xmss = [ ( + vid, self.get_public_keys(vid)[0], self.sign_attestation_data(vid, attestation_data), ) for vid in validator_ids ] return TypeOneMultiSignature.aggregate( - xmss_participants=ValidatorIndices(data=validator_ids).to_aggregation_bits(), children=[], raw_xmss=raw_xmss, message=hash_tree_root(attestation_data), @@ -599,8 +598,7 @@ def build_attestation_proofs( proofs.append( TypeOneMultiSignature.aggregate( children=[], - raw_xmss=list(zip(public_keys, signatures, strict=True)), - xmss_participants=agg.aggregation_bits, + raw_xmss=list(zip(validator_ids, public_keys, signatures, strict=True)), message=hash_tree_root(agg.data), slot=agg.data.slot, ) diff --git a/packages/testing/src/consensus_testing/test_types/block_spec.py b/packages/testing/src/consensus_testing/test_types/block_spec.py index f95b2f459..72aa4a0ec 100644 --- a/packages/testing/src/consensus_testing/test_types/block_spec.py +++ b/packages/testing/src/consensus_testing/test_types/block_spec.py @@ -276,7 +276,6 @@ def _sign_block( Complete signed block. """ block_root = hash_tree_root(final_block) - proposer_participants = ValidatorIndices(data=[proposer_index]).to_aggregation_bits() proposer_pubkey = key_manager.get_public_keys(proposer_index)[1] # The binding rejects placeholder bytes; if anything in the merged @@ -295,8 +294,7 @@ def _sign_block( ) proposer_type_1 = TypeOneMultiSignature.aggregate( children=[], - raw_xmss=[(proposer_pubkey, proposer_signature)], - xmss_participants=proposer_participants, + raw_xmss=[(proposer_index, proposer_pubkey, proposer_signature)], message=block_root, slot=self.slot, ) diff --git a/pyproject.toml b/pyproject.toml index e3bfe06ff..30f716da2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,10 +78,14 @@ known-first-party = ["lean_spec"] [tool.ruff.lint.per-file-ignores] "tests/**" = ["D"] +"src/lean_spec/subspecs/xmss/constants.py" = ["N802"] [tool.ty.environment] python-version = "3.12" +[tool.ty.src] +exclude = [".claude/"] + [tool.ty.terminal] error-on-warning = true diff --git a/src/lean_spec/forks/lstar/aggregation_select.py b/src/lean_spec/forks/lstar/aggregation_select.py new file mode 100644 index 000000000..7f3951503 --- /dev/null +++ b/src/lean_spec/forks/lstar/aggregation_select.py @@ -0,0 +1,51 @@ +"""Greedy proof selection for lstar block production.""" + +from lean_spec.subspecs.xmss.aggregation import TypeOneMultiSignature +from lean_spec.types import ValidatorIndex + + +def select_greedily( + *proof_sets: set[TypeOneMultiSignature] | None, +) -> tuple[list[TypeOneMultiSignature], set[ValidatorIndex]]: + """ + Greedy set-cover over Type-1 proofs maximizing validator coverage. + + Iterates the proof sets in order, repeatedly picking the proof with the + most uncovered validators until no further coverage is possible. + Earlier proof sets are prioritized so gossip-fresh proofs win over + already-known ones. + + The validator-index sets are materialized once per proof, not inside the + inner max key, so the loop runs in O(P * V) instead of O(P^2 * V). + + Args: + *proof_sets: One or more sets of Type-1 proofs, ordered by priority. + None entries are skipped. + + Returns: + The chosen proofs and the union of validator indices they cover. + """ + selected: list[TypeOneMultiSignature] = [] + covered: set[ValidatorIndex] = set() + + for proofs in proof_sets: + if not proofs: + continue + + # Materialize each proof's validator index set once. + # The greedy loop below would otherwise recompute it on every comparison. + coverage_of: dict[TypeOneMultiSignature, set[ValidatorIndex]] = { + p: set(p.participants.to_validator_indices()) for p in proofs + } + remaining = list(proofs) + + while remaining: + best = max(remaining, key=lambda p: len(coverage_of[p] - covered)) + new_coverage = coverage_of[best] - covered + if not new_coverage: + break + selected.append(best) + covered |= new_coverage + remaining.remove(best) + + return selected, covered diff --git a/src/lean_spec/forks/lstar/spec.py b/src/lean_spec/forks/lstar/spec.py index 66e3a2b2b..702ac3c02 100644 --- a/src/lean_spec/forks/lstar/spec.py +++ b/src/lean_spec/forks/lstar/spec.py @@ -5,6 +5,7 @@ from collections.abc import Iterable, Sequence, Set as AbstractSet from typing import Any, ClassVar +from lean_spec.forks.lstar.aggregation_select import select_greedily from lean_spec.forks.lstar.containers import ( AggregatedAttestation, AttestationData, @@ -55,7 +56,6 @@ Uint8, Uint64, ValidatorIndex, - ValidatorIndices, ) from ..protocol import ForkProtocol, SpecBlockType, SpecStateType @@ -763,7 +763,7 @@ def build_block( found_entries = True - selected, _ = TypeOneMultiSignature.select_greedily(proofs) + selected, _ = select_greedily(proofs) aggregated_signatures.extend(selected) for proof in selected: aggregated_attestations.append( @@ -836,7 +836,6 @@ def build_block( for proof in proofs ] sig = TypeOneMultiSignature.aggregate( - xmss_participants=None, children=children, raw_xmss=[], message=hash_tree_root(att_data), @@ -1661,9 +1660,7 @@ def aggregate(self, store: LstarStore) -> tuple[LstarStore, list[SignedAggregate # New payloads go first because they represent uncommitted # work — known payloads fill remaining gaps. - child_proofs, covered = TypeOneMultiSignature.select_greedily( - new.get(data), known.get(data) - ) + child_proofs, covered = select_greedily(new.get(data), known.get(data)) # Phase 2: Fill # @@ -1689,17 +1686,6 @@ def aggregate(self, store: LstarStore) -> tuple[LstarStore, list[SignedAggregate if not raw_entries and len(child_proofs) < 2: continue - # Encode raw signers as a compact bitfield when present. - # Child-only aggregation (no raw signatures) must pass None. - if raw_entries: - xmss_participants = ValidatorIndices( - data=[vid for vid, _, _ in raw_entries] - ).to_aggregation_bits() - raw_xmss = [(pk, sig) for _, pk, sig in raw_entries] - else: - xmss_participants = None - raw_xmss = [] - # Phase 3: Aggregate # # Build the recursive proof tree. @@ -1719,11 +1705,11 @@ def aggregate(self, store: LstarStore) -> tuple[LstarStore, list[SignedAggregate ] # Hand everything to the XMSS subspec. + # Each fresh entry already carries its validator index alongside its key and signature. # Out comes a single proof covering all selected validators. proof = TypeOneMultiSignature.aggregate( - xmss_participants=xmss_participants, children=children, - raw_xmss=raw_xmss, + raw_xmss=raw_entries, message=hash_tree_root(data), slot=data.slot, ) diff --git a/src/lean_spec/subspecs/sync/service.py b/src/lean_spec/subspecs/sync/service.py index 305c8dd84..5d3e1b9fa 100644 --- a/src/lean_spec/subspecs/sync/service.py +++ b/src/lean_spec/subspecs/sync/service.py @@ -617,13 +617,13 @@ def _deconstruct_block_into_store( continue try: + # The split takes the bits from the block attestation this + # component binds, since the Rust binding does not return them. block_t1 = type_two.split_by_msg( message=data_root, public_keys_per_message=public_keys_per_message, + participants=att.aggregation_bits, ) - # split_by_msg returns an empty participant bitfield; restore - # the bits from the block attestation this component binds. - block_t1 = block_t1.model_copy(update={"participants": att.aggregation_bits}) if local_proofs: combined = TypeOneMultiSignature.aggregate( @@ -638,7 +638,6 @@ def _deconstruct_block_into_store( for child in (block_t1, *local_proofs) ], raw_xmss=[], - xmss_participants=None, message=data_root, slot=data.slot, ) diff --git a/src/lean_spec/subspecs/validator/service.py b/src/lean_spec/subspecs/validator/service.py index 2cf71bf01..c0d3099f9 100644 --- a/src/lean_spec/subspecs/validator/service.py +++ b/src/lean_spec/subspecs/validator/service.py @@ -50,7 +50,7 @@ from lean_spec.subspecs.xmss import TARGET_SIGNATURE_SCHEME from lean_spec.subspecs.xmss.aggregation import TypeOneMultiSignature, TypeTwoMultiSignature from lean_spec.subspecs.xmss.containers import PublicKey, Signature -from lean_spec.types import ByteList512KiB, Bytes32, Slot, Uint64, ValidatorIndex, ValidatorIndices +from lean_spec.types import ByteList512KiB, Bytes32, Slot, Uint64, ValidatorIndex from .constants import HYSTERESIS_BAND, NETWORK_STALL_THRESHOLD, SYNC_LAG_THRESHOLD from .registry import ValidatorEntry, ValidatorRegistry @@ -447,12 +447,10 @@ def _sign_block( proposer_pubkey = validators[validator_index].get_proposal_pubkey() # Wrap the proposer's raw XMSS signature into a singleton Type-1. - # The participant set is just the proposer index. - proposer_participants = ValidatorIndices(data=[validator_index]).to_aggregation_bits() + # The single fresh entry carries the proposer index alongside its key and signature. proposer_type_1 = TypeOneMultiSignature.aggregate( children=[], - raw_xmss=[(proposer_pubkey, proposer_signature)], - xmss_participants=proposer_participants, + raw_xmss=[(validator_index, proposer_pubkey, proposer_signature)], message=block_root, slot=block.slot, ) diff --git a/src/lean_spec/subspecs/xmss/__init__.py b/src/lean_spec/subspecs/xmss/__init__.py index a699229c9..3fd1f7ed3 100644 --- a/src/lean_spec/subspecs/xmss/__init__.py +++ b/src/lean_spec/subspecs/xmss/__init__.py @@ -1,8 +1,10 @@ -""" -This package provides a Python specification for the Generalized XMSS -hash-based signature scheme. +"""Generalized XMSS hash-based signature scheme. -It exposes the core data structures and the main interface functions. +References: + - Hash-Based Multi-Signatures for Post-Quantum Ethereum. + https://eprint.iacr.org/2025/055.pdf + - Aborting Random Oracles, How to Build Them, How to Use Them. + https://eprint.iacr.org/2026/016.pdf """ from .containers import PublicKey, SecretKey diff --git a/src/lean_spec/subspecs/xmss/aggregation.py b/src/lean_spec/subspecs/xmss/aggregation.py index 6b2ba0094..4a0816d0e 100644 --- a/src/lean_spec/subspecs/xmss/aggregation.py +++ b/src/lean_spec/subspecs/xmss/aggregation.py @@ -1,13 +1,12 @@ -""" -Multi-signature aggregation for the Lean Ethereum consensus spec. +"""Multi-signature aggregation over Generalized XMSS. Two proof shapes: -- Type-1: many validators on one message (one AttestationData, or one block root). -- Type-2: a merge of N Type-1 proofs over distinct messages. -""" +- Type-1: many validators on a single message (one attestation, or one block root). +- Type-2: a merge of several Type-1 proofs over distinct messages. -from collections.abc import Sequence +The Rust binding owns proof construction and cryptographic checks. +""" from lean_multisig_py import ( aggregate_type_1, @@ -18,7 +17,7 @@ verify_type_2_with_messages, ) -from lean_spec.config import LEAN_ENV, LeanEnvMode +from lean_spec.config import LEAN_ENV from lean_spec.types import ( AggregationBits, ByteList512KiB, @@ -28,33 +27,35 @@ ValidatorIndex, ValidatorIndices, ) -from lean_spec.types.boolean import Boolean from .containers import PublicKey, Signature -LOG_INV_RATE_TEST = 1 -""" -Inverse rate exponent for test mode (fastest, biggest proofs). +LOG_INV_RATE: int = 1 if LEAN_ENV == "test" else 2 +"""Inverse-rate exponent forwarded to the SNARK backend. -This parameter is forwarded to `lean_multisig_py` prover and controls a performance/size trade-off: - -- Lower values generate proofs faster but increase proof size. -- Higher values reduce proof size but increase prover work. +- A smaller rate trades verifier cost for prover speed. +- Test mode favors prover speed. """ -LOG_INV_RATE_PROD = 2 -"""Inverse rate exponent for production mode (balanced speed vs proof size).""" +# The environment is fixed for the lifetime of the process. +# +# One setup call covers every aggregation, verification, split, and merge below. +# +# Per-call invocations then default to the mode established here. +setup_prover(mode=LEAN_ENV) class AggregationError(Exception): - """Raised when signature aggregation, merging, splitting, or verification fails.""" + """Raised when aggregation, merging, splitting, or verification fails.""" class TypeOneMultiSignature(Container): - """A single-message proof aggregating signatures from many validators. + """Single-message proof aggregating signatures from many validators. + + Every validator signs the same message for the same slot. - The signed message and slot are rederived by the verifier from the - block body it already trusts, so they live outside the proof envelope. + The message and slot stay outside the proof. + The verifier rederives them from the block body it already trusts. """ participants: AggregationBits @@ -63,275 +64,274 @@ class TypeOneMultiSignature(Container): proof: ByteList512KiB """Aggregated proof bytes in compact no-pubkeys representation.""" - @staticmethod - def select_greedily( - *proof_sets: set["TypeOneMultiSignature"] | None, - ) -> tuple[list["TypeOneMultiSignature"], set[ValidatorIndex]]: - """Greedy set-cover over Type-1 proofs to maximise validator coverage. - - Repeatedly selects the proof covering the most uncovered validators - until no proof adds new coverage. Earlier proof sets are - prioritised: gossip-fresh proofs win over already-known ones. - """ - selected: list[TypeOneMultiSignature] = [] - covered: set[ValidatorIndex] = set() - - for proofs in proof_sets: - if not proofs: - continue - - remaining = list(proofs) - - while remaining: - best = max( - remaining, - key=lambda p: len(set(p.participants.to_validator_indices()) - covered), - ) - new_coverage = set(best.participants.to_validator_indices()) - covered - - if not new_coverage: - break - - selected.append(best) - covered |= new_coverage - remaining.remove(best) - - return selected, covered - - @staticmethod + @classmethod def aggregate( - children: Sequence[tuple["TypeOneMultiSignature", Sequence[PublicKey]]], - raw_xmss: Sequence[tuple[PublicKey, Signature]], - xmss_participants: AggregationBits | None, + cls, + children: list[tuple["TypeOneMultiSignature", list[PublicKey]]], + raw_xmss: list[tuple[ValidatorIndex, PublicKey, Signature]], message: Bytes32, slot: Slot, - mode: LeanEnvMode | None = None, ) -> "TypeOneMultiSignature": - """Aggregate raw XMSS signatures and child Type-1 proofs into one Type-1 proof. + """Fold fresh signatures and child proofs into one single-message proof. - Proof bytes are stored in compact no-pubkeys form. Participant identity is - tracked separately in participants (attestation bits on the wire). - """ - if not raw_xmss and not children: - raise AggregationError("At least one raw signature or child proof is required") + # Overview - if raw_xmss and xmss_participants is None: - raise AggregationError("xmss_participants is required when raw_xmss is provided") + Two kinds of contribution merge into one proof. - if not raw_xmss and len(children) < 2: - raise AggregationError( - "At least two child proofs are required when no raw signatures are provided" - ) + - A fresh signer contributes a single raw signature. + - A child proof contributes an already-aggregated bundle of signers. - aggregated_validator_ids: set[ValidatorIndex] = set() - if xmss_participants is not None: - aggregated_validator_ids.update(xmss_participants.to_validator_indices()) + The result names the union of every contributing validator. + The prover compresses all contributions into one proof over the shared message. - if len(aggregated_validator_ids) != len(raw_xmss): - raise AggregationError("Raw signature count does not match XMSS participant count") + # Why the index travels with each fresh signer - # Include child participants in the aggregated participants - for child, _ in children: - aggregated_validator_ids.update(child.participants.to_validator_indices()) - participants = ValidatorIndices(data=sorted(aggregated_validator_ids)).to_aggregation_bits() + A public key carries no validator index on its own. + Pairing the index with each fresh entry lets the bitfield be derived, not passed in. + An empty list of fresh signers simply contributes no indices. - mode = mode or LEAN_ENV - setup_prover(mode=mode) - log_inv_rate = LOG_INV_RATE_TEST if mode == "test" else LOG_INV_RATE_PROD + Args: + children: Child proofs, each paired with the public keys it names. + raw_xmss: Fresh entries, each carrying its validator index, public key, and signature. + message: The 32-byte message every signer signed. + slot: The slot every signer signed for. - raw_pubkeys_ssz = [pk.encode_bytes() for pk, _ in raw_xmss] - raw_signatures_ssz = [sig.encode_bytes() for _, sig in raw_xmss] + Returns: + A single-message proof covering the union of all participants. - children_bytes: list[tuple[list[bytes], bytes]] = [] - for idx, (child, child_public_keys_raw) in enumerate(children): - child_public_keys = list(child_public_keys_raw) - expected = child.participants.data.count(Boolean(1)) - if len(child_public_keys) != expected: - raise AggregationError( - f"Type-1 aggregate child {idx} expected {expected} pubkeys, " - f"got {len(child_public_keys)}" - ) - - child_pks_ssz = [pk.encode_bytes() for pk in child_public_keys] - child_wire = bytes(child.proof.data) - if not child_wire: - raise AggregationError(f"Child proof {idx} has empty proof bytes") - children_bytes.append((child_pks_ssz, child_wire)) + Raises: + AggregationError: When the prover rejects the inputs. + """ + # Phase 1: union every contributing validator index. + # + # Fresh signers bring their own index. + # Child proofs expose theirs through the participant bitfield. + all_indices = {vid for vid, _, _ in raw_xmss}.union( + *(child.participants.to_validator_indices() for child, _ in children) + ) + participants = ValidatorIndices(data=sorted(all_indices)).to_aggregation_bits() + + # Phase 2: serialize inputs to the prover's wire format. + raw_pubkeys_ssz = [pk.encode_bytes() for _, pk, _ in raw_xmss] + raw_signatures_ssz = [sig.encode_bytes() for _, _, sig in raw_xmss] + children_bytes = [ + ([pk.encode_bytes() for pk in pubkeys], bytes(child.proof.data)) + for child, pubkeys in children + ] + # Phase 3: hand off to the Rust prover. + # The mode argument routes the call to the matching backend bytecode. try: _, type1_wire = aggregate_type_1( raw_pubkeys_ssz, raw_signatures_ssz, bytes(message), int(slot), - log_inv_rate, - children_bytes if children_bytes else None, - mode=mode, + LOG_INV_RATE, + children_bytes or None, + mode=LEAN_ENV, ) except Exception as exc: - raise AggregationError(f"Type-1 aggregation failed: {exc}") from exc + raise AggregationError(str(exc)) from exc - return TypeOneMultiSignature( - participants=participants, - proof=ByteList512KiB(data=type1_wire), - ) + return cls(participants=participants, proof=ByteList512KiB(data=type1_wire)) def verify( self, - public_keys: Sequence[PublicKey], + public_keys: list[PublicKey], message: Bytes32, slot: Slot, - mode: LeanEnvMode | None = None, ) -> None: - """Verify this single-message Type-1 proof against a resolved set of pubkeys.""" - mode = mode or LEAN_ENV - setup_prover(mode=mode) + """Verify this single-message Type-1 proof against a pubkey set. - expected = self.participants.data.count(Boolean(1)) + Args: + public_keys: Pubkeys for the validators named by participants. + message: Message bound by the proof. + slot: Slot bound by the proof. + + Raises: + AggregationError: When the pubkey count does not match the bitfield + or the Rust verifier rejects the proof. + """ + # The bitfield names one validator per set bit. + # The caller must supply exactly that many keys, in the same order. + # A miscount would otherwise fail deep in the verifier with an opaque error. + expected = len(self.participants.to_validator_indices()) if len(public_keys) != expected: raise AggregationError( f"Type-1 verify expected {expected} pubkeys for participants, " f"got {len(public_keys)}" ) - pks_ssz = [pk.encode_bytes() for pk in public_keys] + # Hand the resolved keys, message, and slot to the Rust verifier. + # The mode argument selects the matching backend bytecode. try: verify_type_1( - pks_ssz, + [pk.encode_bytes() for pk in public_keys], bytes(message), int(slot), bytes(self.proof.data), - mode=mode, + mode=LEAN_ENV, ) except Exception as exc: raise AggregationError(f"Type-1 verification failed: {exc}") from exc class TypeTwoMultiSignature(Container): - """A merged proof covering many distinct messages. + """Merged proof covering many distinct messages. - On the wire a SignedBlock carries the SSZ-serialised form of this - container as its single proof blob. + Each component is a single-message proof over its own message. + Merging binds the components into one proof the block can carry whole. + + A signed block stores this proof as a single serialized blob. """ proof: ByteList512KiB """Compact no-pubkeys serialized Type-2 proof bytes.""" - @staticmethod + @classmethod def aggregate( - parts: Sequence[TypeOneMultiSignature], - public_keys_per_part: Sequence[Sequence[PublicKey]] | None = None, - mode: LeanEnvMode | None = None, + cls, + parts: list[TypeOneMultiSignature], + public_keys_per_part: list[list[PublicKey]], ) -> "TypeTwoMultiSignature": - """Merge several Type-1 proofs (each over a distinct message) into one Type-2 proof. + """Merge several single-message proofs over distinct messages into one. + + # Why the public keys are passed in + + - A merged proof stores no public keys. + - The prover needs them as external context to fold the components together. + - They cannot be recovered from the proofs, so the caller supplies them. - The returned Type-2 proof bytes are stored in compact no-pubkeys form. + Args: + parts: The single-message proofs to merge, one per distinct message. + public_keys_per_part: Public keys for each component, in the same order as the proofs. + + Returns: + A merged proof binding every component to its own message. + + Raises: + AggregationError: When no proofs are given, a pubkey list disagrees + with its participant count, or the prover rejects the inputs. """ if not parts: raise AggregationError("Type-2 aggregate requires at least one Type-1 input") - mode = mode or LEAN_ENV - setup_prover(mode=mode) - log_inv_rate = LOG_INV_RATE_TEST if mode == "test" else LOG_INV_RATE_PROD - - if public_keys_per_part is not None and len(public_keys_per_part) != len(parts): - raise AggregationError( - f"Type-2 aggregate expected pubkeys for {len(parts)} parts, " - f"got {len(public_keys_per_part)}" - ) - + # Each component carries the public keys named by its bitfield, in the same order. + # + # A miscount would otherwise fail deep in the prover with an opaque error. type1_entries: list[tuple[list[bytes], bytes]] = [] - for idx, part in enumerate(parts): - expected = part.participants.data.count(Boolean(1)) - if public_keys_per_part is None: - raise AggregationError( - "public_keys_per_part is required when Type-1 proofs are stored without pubkeys" - ) - pubkeys = list(public_keys_per_part[idx]) + for idx, (part, pubkeys) in enumerate(zip(parts, public_keys_per_part, strict=True)): + expected = len(part.participants.to_validator_indices()) if len(pubkeys) != expected: raise AggregationError( f"Type-2 aggregate entry {idx} expected {expected} pubkeys, got {len(pubkeys)}" ) - pks_ssz = [pk.encode_bytes() for pk in pubkeys] - type1_entries.append((pks_ssz, bytes(part.proof.data))) + type1_entries.append(([pk.encode_bytes() for pk in pubkeys], bytes(part.proof.data))) + # Hand the per-component keys and proof bytes to the Rust prover. + # + # The mode argument selects the matching backend bytecode. try: - _, type2_wire = merge_many_type_1(type1_entries, log_inv_rate, mode=mode) + _, type2_wire = merge_many_type_1(type1_entries, LOG_INV_RATE, mode=LEAN_ENV) except Exception as exc: - raise AggregationError(f"Type-2 aggregation failed: {exc}") from exc + raise AggregationError(str(exc)) from exc - return TypeTwoMultiSignature(proof=ByteList512KiB(data=type2_wire)) + return cls(proof=ByteList512KiB(data=type2_wire)) def split_by_msg( self, message: Bytes32, - public_keys_per_message: Sequence[Sequence[PublicKey]], - mode: LeanEnvMode | None = None, + public_keys_per_message: list[list[PublicKey]], + participants: AggregationBits, ) -> TypeOneMultiSignature: - """Recover the Type-1 proof bound to a specific message from this Type-2 merge. + """Recover the Type-1 proof bound to one message from this Type-2 merge. - public_keys_per_message defines the per-component pubkey layout the - Type-2 was built with. - """ - mode = mode or LEAN_ENV - setup_prover(mode=mode) - log_inv_rate = LOG_INV_RATE_TEST if mode == "test" else LOG_INV_RATE_PROD + # Why the layout and participants are passed in + + - A merged proof stores neither the public keys nor the participant bitfields. + - The prover needs the original key layout to isolate one component. + - The caller supplies both, drawn from the block attestation this component binds. + Args: + message: Message that selects the Type-1 component. + public_keys_per_message: Pubkey layout this Type-2 was built with. + participants: Bitfield naming the validators of the recovered component. + + Returns: + The Type-1 proof bound to the message. + + Raises: + AggregationError: When the Rust binding rejects the split. + """ + # Each component carries the public keys named by its bitfield, in the same order. pub_keys_per_component_ssz: list[list[bytes]] = [ [pk.encode_bytes() for pk in pks] for pks in public_keys_per_message ] + # Hand the key layout, merged proof, and selector message to the Rust prover. + # + # The mode argument selects the matching backend bytecode. try: _, type1_wire = split_type_2_by_msg( pub_keys_per_component_ssz, bytes(self.proof.data), bytes(message), - log_inv_rate, - mode=mode, + LOG_INV_RATE, + mode=LEAN_ENV, ) except Exception as exc: - raise AggregationError(f"Type-2 split-by-message failed: {exc}") from exc + raise AggregationError(f"Type-2 split failed: {exc}") from exc return TypeOneMultiSignature( - participants=AggregationBits(data=[]), + participants=participants, proof=ByteList512KiB(data=type1_wire), ) def verify( self, - public_keys_per_message: Sequence[Sequence[PublicKey]], - messages: Sequence[tuple[Bytes32, Slot]], - mode: LeanEnvMode | None = None, + public_keys_per_message: list[list[PublicKey]], + messages: list[tuple[Bytes32, Slot]], ) -> None: - """Verify this multi-message Type-2 proof. - - Each entry of public_keys_per_message corresponds to one Type-1 - component merged into this Type-2. - The parallel messages entry binds that component to a specific - message hash and slot. - Without this binding the proof would verify against any attacker - chosen attestation data that resolves to the same pubkeys. - """ - mode = mode or LEAN_ENV - setup_prover(mode=mode) + """Verify this multi-message proof against its per-component bindings. + + # The message bindings + Each component is checked against one message and slot supplied by the caller. + Without that binding the proof would accept attacker-chosen data resolving to the same keys. + The parallel lists pin every component to the message it actually signed. + + Args: + public_keys_per_message: Public keys for each component, in component order. + messages: Message-slot pair each component is bound to, parallel to the keys. + + Raises: + AggregationError: When the two lists disagree in length, or the verifier rejects. + """ + # Each component needs exactly one message-slot binding. + # + # A length mismatch would leave components unbound or misaligned. if len(messages) != len(public_keys_per_message): raise AggregationError( f"Type-2 verify expected {len(public_keys_per_message)} message bindings, " f"got {len(messages)}" ) + # Serialize the key layout and the per-component message bindings. pub_keys_per_component_ssz: list[list[bytes]] = [ [pk.encode_bytes() for pk in pks] for pks in public_keys_per_message ] expected_messages = [(bytes(msg), int(slot)) for msg, slot in messages] + # Hand the layout, bindings, and merged proof to the Rust verifier. + # + # The mode argument selects the matching backend bytecode. try: verify_type_2_with_messages( pub_keys_per_component_ssz, expected_messages, bytes(self.proof.data), - mode=mode, + mode=LEAN_ENV, ) except Exception as exc: raise AggregationError(f"Type-2 verification failed: {exc}") from exc diff --git a/src/lean_spec/subspecs/xmss/constants.py b/src/lean_spec/subspecs/xmss/constants.py index f6dd295a1..7f5c193f0 100644 --- a/src/lean_spec/subspecs/xmss/constants.py +++ b/src/lean_spec/subspecs/xmss/constants.py @@ -1,15 +1,4 @@ -""" -Defines the cryptographic constants and configuration presets for the -XMSS spec. - -This specification corresponds to the "hashing-optimized" Top Level Target Sum -instantiation from the canonical Rust implementation -(production instantiation). - -We also provide a test instantiation for testing purposes. -""" - -from __future__ import annotations +"""Cryptographic constants and configuration presets for the XMSS spec.""" import math from typing import Final @@ -26,51 +15,34 @@ class XmssConfig(StrictBaseModel): """A model holding the configuration constants for an XMSS preset.""" - # --- Core Scheme Configuration --- MESSAGE_LENGTH: int """The length in bytes for all messages to be signed.""" LOG_LIFETIME: int """The base-2 logarithm of the scheme's maximum lifetime.""" - @property - def LIFETIME(self) -> Uint64: # noqa: N802 - """ - The maximum number of slots supported by this configuration. - - An individual key pair can be active for a smaller sub-range. - """ - return Uint64(1 << self.LOG_LIFETIME) - DIMENSION: int - """The total number of hash chains, `v`.""" + """The total number of hash chains, v.""" BASE: int """The alphabet size for the digits of the encoded message.""" Z: int - """Number of base-`BASE` digits extracted from each field element.""" + """Number of base-BASE digits extracted from each field element.""" Q: int - """Quotient such that `Q * BASE^Z == P - 1`.""" + """Quotient such that Q * BASE^Z == P - 1.""" TARGET_SUM: int """The required sum of all codeword chunks for a signature to be valid.""" MAX_TRIES: int - """ - How often one should try at most to resample a random value. - - This is currently based on experiments with the Rust implementation. - Should probably be modified in production. - """ + """How often one should try at most to resample a random value.""" PARAMETER_LEN: int - """ - The length of the public parameter `P`. + """The length of the public parameter P. - It is used to specialize the hash function. - """ + It is used to specialize the hash function.""" TWEAK_LEN_FE: int """The length of a domain-separating tweak.""" @@ -79,7 +51,7 @@ def LIFETIME(self) -> Uint64: # noqa: N802 """The length of a message after being encoded into field elements.""" RAND_LEN_FE: int - """The length of the randomness `rho` used during message encoding.""" + """The length of the randomness rho used during message encoding.""" HASH_LEN_FE: int """The output length of the main tweakable hash function.""" @@ -88,33 +60,48 @@ def LIFETIME(self) -> Uint64: # noqa: N802 """The capacity of the Poseidon1 sponge, defining its security level.""" @model_validator(mode="after") - def _validate_decomposition(self) -> XmssConfig: + def _validate_decomposition(self) -> "XmssConfig": """Verify that Q * BASE^Z == P - 1.""" if self.Q * self.BASE**self.Z != P - 1: raise ValueError(f"Q * BASE^Z must equal P-1={P - 1}") return self @property - def MH_HASH_LEN_FE(self) -> int: # noqa: N802 + def LIFETIME(self) -> Uint64: + """The maximum number of slots supported by this configuration. + + An individual key pair can be active for a smaller sub-range. + """ + return Uint64(1 << self.LOG_LIFETIME) + + @property + def MH_HASH_LEN_FE(self) -> int: """Number of Poseidon output field elements needed for the aborting decode.""" return math.ceil(self.DIMENSION / self.Z) @property - def SIGNATURE_LEN_BYTES(self) -> int: # noqa: N802 - """ - The SSZ-encoded size of a signature in bytes. + def SIGNATURE_LEN_BYTES(self) -> int: + """The SSZ-encoded size of a signature in bytes. - Includes raw field data plus SSZ offset overhead for variable-size fields: + # Layout - - Signature container: 2 offsets (path, hashes) - - HashTreeOpening container: 1 offset (siblings) + authentication path : one sibling digest per tree level (variable) + encoding randomness : fixed run of field elements (fixed) + released chain ends : one digest per hash chain (variable) """ - # Raw data sizes + # One sibling digest per level climbed from leaf to root. path_siblings_size = self.LOG_LIFETIME * self.HASH_LEN_FE * P_BYTES rho_size = self.RAND_LEN_FE * P_BYTES + # One released chain end per chain, so the count is the scheme dimension. hashes_size = self.DIMENSION * self.HASH_LEN_FE * P_BYTES - # SSZ offset overhead: 3 variable fields × 4 bytes each + # SSZ writes a four-byte offset ahead of each variable-length field. + # + # path -> offset 1 (top level) + # chain ends -> offset 2 (top level) + # siblings -> offset 3 (nested inside the path) + # + # The randomness is fixed-length, so it carries no offset. ssz_offset_overhead = 3 * BYTES_PER_LENGTH_OFFSET return path_siblings_size + rho_size + hashes_size + ssz_offset_overhead @@ -170,10 +157,5 @@ def SIGNATURE_LEN_BYTES(self) -> int: # noqa: N802 PRF_KEY_LENGTH: Final = 32 """The length of the PRF secret key in bytes.""" -_LEAN_ENV_TO_CONFIG = { - "test": TEST_CONFIG, - "prod": PROD_CONFIG, -} - -TARGET_CONFIG: Final = _LEAN_ENV_TO_CONFIG[LEAN_ENV] -"""The active XMSS configuration based on LEAN_ENV environment variable.""" +TARGET_CONFIG: Final = TEST_CONFIG if LEAN_ENV == "test" else PROD_CONFIG +"""Active configuration selected at import time from the LEAN_ENV environment variable.""" diff --git a/src/lean_spec/subspecs/xmss/containers.py b/src/lean_spec/subspecs/xmss/containers.py index 2057bc6c8..4db8dac77 100644 --- a/src/lean_spec/subspecs/xmss/containers.py +++ b/src/lean_spec/subspecs/xmss/containers.py @@ -11,7 +11,7 @@ from ...types.container import Container from ...types.exceptions import SSZError from .constants import TARGET_CONFIG -from .subtree import HashSubTree +from .merkle import HashSubTree from .types import ( HashDigestList, HashDigestVector, @@ -33,7 +33,7 @@ class PublicKey(Container): class Signature(Container): - """A single XMSS signature for one (slot, message) under one public key.""" + """A single XMSS signature for one slot and message under one public key.""" path: HashTreeOpening """Authentication path from the one-time key up to the Merkle root.""" @@ -58,56 +58,57 @@ def get_byte_length(cls) -> int: @model_serializer(mode="plain", when_used="json") def _serialize_as_bytes(self) -> str: - """Serialize as "0x"-prefixed hex of the SSZ bytes for JSON output.""" + """Serialize as a 0x-prefixed hex string for JSON output.""" return "0x" + self.encode_bytes().hex() class SecretKey(Container): - """ - Private state of an XMSS key pair. MUST BE KEPT CONFIDENTIAL. - - The tree of one-time keys is split into a top tree plus many bottom trees. - - Each bottom tree: - - has width W = 2^(LOG_LIFETIME / 2), - - covers W slots. + """Private state of an XMSS key pair. - Bottom tree i covers slots [i*W, (i+1)*W). + Must be kept confidential. - The signer keeps the full top tree resident, plus two adjacent bottom trees. + # Tree layout - This means that we have a sliding window of 2W consecutive slots. + - The one-time keys split into one top tree over many bottom trees. + - Each bottom tree spans W = 2^(LOG_LIFETIME / 2) slots. + - Bottom tree i covers the W slots starting at i times W. - These slots can be signed immediately without recomputing anything. + # Sliding window - Memory stays O(sqrt(LIFETIME)) regardless of how long the key lives. + - The signer keeps the top tree resident plus two adjacent bottom trees. + - That window can sign 2W consecutive slots. + - Resident memory stays near the square root of the lifetime. """ prf_key: PRFKey - """Master secret seed; every one-time key is derived from this.""" + """Master secret seed. + + Every one-time key is derived from this. + """ parameter: Parameter """Public parameter mirrored here so signing is self-contained.""" activation_slot: Slot - """ - First slot this key can sign for. + """First slot this key can sign for. - Aligned down to a multiple of W so each bottom tree covers exactly W slots. + - Aligned down to a multiple of W. + - Each bottom tree then covers exactly W slots. """ num_active_slots: Uint64 - """ - Number of consecutive slots this key can sign for. + """Number of consecutive slots this key can sign for. - Rounded up to a multiple of W, with a minimum of 2W so the prepared window always fits. + - Rounded up to a multiple of W, with a minimum of 2W. + - The prepared window then always fits. """ top_tree: HashSubTree """ Full top tree, always resident. - Its lowest layer holds the bottom-tree roots; its top is the public-key root. + - Its lowest layer holds the bottom-tree roots. + - Its top layer is the public-key root. """ left_bottom_tree_index: Uint64 @@ -137,7 +138,10 @@ class KeyPair(StrictBaseModel): @field_validator("public_key", mode="before") @classmethod def _decode_public_key(cls, value: object) -> object: - """Decode a hex string into a PublicKey; pass other inputs through.""" + """Decode hex strings to a public key. + + Other input shapes pass through unchanged. + """ if not isinstance(value, str): return value try: @@ -148,7 +152,10 @@ def _decode_public_key(cls, value: object) -> object: @field_validator("secret_key", mode="before") @classmethod def _decode_secret_key(cls, value: object) -> object: - """Decode a hex string into a SecretKey; pass other inputs through.""" + """Decode hex strings to a secret key. + + Other input shapes pass through unchanged. + """ if not isinstance(value, str): return value try: @@ -163,25 +170,16 @@ def _encode_hex(self, value: PublicKey | SecretKey) -> str: class ValidatorKeyPair(StrictBaseModel): - """ - Two independent XMSS key pairs for one validator's two signing roles. + """Two independent XMSS key pairs for one validator's two signing roles. A validator signs two messages per slot: - - one attestation (always), - - one block proposal (only when chosen as proposer). - - A one-time signature exhausts a leaf, so a single XMSS key cannot serve - both roles in the same slot. - Splitting into two independent pairs lets a validator sign both - messages from independent Winternitz chains. - - JSON shape on disk: + - one attestation, always, + - one block proposal, only when chosen as proposer. - { - "attestation_keypair": {"public_key": "", "secret_key": ""}, - "proposal_keypair": {"public_key": "", "secret_key": ""} - } + A one-time signature exhausts a leaf. + So one key cannot cover both roles in the same slot. + Two independent pairs let each role sign from its own Winternitz chains. """ attestation_keypair: KeyPair diff --git a/src/lean_spec/subspecs/xmss/encoding.py b/src/lean_spec/subspecs/xmss/encoding.py new file mode 100644 index 000000000..0dd35aba7 --- /dev/null +++ b/src/lean_spec/subspecs/xmss/encoding.py @@ -0,0 +1,177 @@ +"""Message-to-codeword pipeline for the Generalized XMSS scheme. + +# Overview + +The pipeline has two layers: + +- Encoding maps a message to a codeword, the vector of digits a signature commits to. +- Decoding is the inner step that unpacks hash field elements into those digits. + +A codeword is a vertex of a high-dimensional hypercube. +Encoding accepts the vertex only when its digits lie on the target-sum layer. +A signer retries with fresh randomness until the filter accepts. + +Concretely, the test preset uses 4 digits in base 8 with target sum 6: + +- The vector [5, 0, 1, 0] sums to 6, lands on the layer, and is accepted. +- The vector [2, 2, 2, 2] sums to 8, misses the layer, and forces a retry. + +# Construction + +The hypercube decode is the aborting encoding of Aborting Random Oracles, Section 6.1. +https://eprint.iacr.org/2026/016.pdf + +The target-sum filter is the top-level acceptance test from the canonical Rust instantiation. + +# Why the decode can abort + +The decode turns each uniform field element into uniform base-BASE digits. +Uniform output is what lets the security analysis model the hash as a random oracle. + +A field element takes one of the P values from 0 to P - 1. +The prime is chosen so that P - 1 = Q * BASE^Z. +The values 0 to P - 2 therefore form BASE^Z groups of Q consecutive integers. + +Integer division by Q maps each group to one quotient in 0 to BASE^Z - 1: + + 0 .. Q-1 -> quotient 0 + Q .. 2Q-1 -> quotient 1 + ... + P-1-Q .. P-2 -> quotient BASE^Z - 1 + P-1 -> no quotient + +Each quotient expands into Z base-BASE digits. +Every quotient is equally likely, which makes the digits uniform. + +The value P - 1 falls outside every group. +The decode rejects it, a rare event near 4.7e-10 that barely affects signing. +""" + +from lean_spec.types import Bytes32, Uint64 + +from ..koalabear import Fp +from .constants import TWEAK_PREFIX_MESSAGE, XmssConfig +from .field import int_to_base_p +from .poseidon import PoseidonXmss +from .types import Parameter, Randomness + + +def encode_message(config: XmssConfig, message: Bytes32) -> list[Fp]: + """Encode a 32-byte message into field elements via base-P decomposition. + + The bytes are read little-endian as a single integer. + """ + acc = int.from_bytes(message, "little") + return int_to_base_p(acc, config.MSG_LEN_FE) + + +def encode_epoch(config: XmssConfig, epoch: Uint64) -> list[Fp]: + """Encode the epoch and the message-hash subdomain into field elements. + + The 8-bit prefix separates the message-hash subdomain from the chain and tree subdomains. + """ + # Layout: + # + # (epoch << 8) | MESSAGE_PREFIX + acc = (int(epoch) << 8) | TWEAK_PREFIX_MESSAGE + return int_to_base_p(acc, config.TWEAK_LEN_FE) + + +def aborting_decode(config: XmssConfig, field_elements: list[Fp]) -> list[int] | None: + """Reject-sample each field element into base-BASE digits. + + For each element A_i: + + 1. If A_i >= Q * BASE^Z, that is A_i == P - 1, abort and return None. + 2. Compute d_i = A_i // Q in [0, BASE^Z - 1]. + 3. Emit Z base-BASE digits of d_i, least significant first. + + Return the first DIMENSION digits. + """ + threshold = config.Q * config.BASE**config.Z + + digits: list[int] = [] + for fe in field_elements: + a = int(fe) + + # The only rejection case is A_i == P - 1. + if a >= threshold: + return None + + # Quotient by Q strips the residue. + # The remainder is uniform in [0, BASE^Z - 1]. + d = a // config.Q + for _ in range(config.Z): + d, digit = divmod(d, config.BASE) + digits.append(digit) + + return digits[: config.DIMENSION] + + +def message_hash( + poseidon: PoseidonXmss, + config: XmssConfig, + parameter: Parameter, + epoch: Uint64, + rho: Randomness, + message: Bytes32, +) -> list[int] | None: + """Hash the inputs with Poseidon1 and decode into a candidate codeword. + + Args: + poseidon: Cached Poseidon1 engine. + config: Active XMSS configuration. + parameter: Public parameter P. + epoch: Current epoch. + rho: Per-attempt randomness. + message: Message being signed. + + Returns: + Codeword of DIMENSION digits in [0, BASE-1], or None on rejection. + """ + # Encode the message and epoch as field elements before hashing. + message_fe = encode_message(config, message) + epoch_fe = encode_epoch(config, epoch) + + # One Poseidon1 call produces enough output for the aborting decode. + base_input = message_fe + parameter.elements + epoch_fe + rho.elements + poseidon_output = poseidon.compress(base_input, 24, config.MH_HASH_LEN_FE) + + return aborting_decode(config, poseidon_output) + + +def target_sum_encode( + poseidon: PoseidonXmss, + config: XmssConfig, + parameter: Parameter, + message: Bytes32, + rho: Randomness, + epoch: Uint64, +) -> list[int] | None: + """Encode a message into a codeword if it meets the target sum. + + The signer retries with fresh randomness on rejection. + + Args: + poseidon: Cached Poseidon1 engine. + config: Active XMSS configuration. + parameter: Public parameter for domain separation. + message: Message being signed. + rho: Per-attempt randomness. + epoch: Current epoch. + + Returns: + Codeword on success, None when the attempt must be retried. + """ + # Phase 1: aborting hypercube decode of the Poseidon1 output. + codeword_candidate = message_hash(poseidon, config, parameter, epoch, rho, message) + if codeword_candidate is None: + return None + + # Phase 2: target-sum acceptance condition. + # A valid codeword is a vertex on the hypercube layer whose digit sum is TARGET_SUM. + if sum(codeword_candidate) == config.TARGET_SUM: + return codeword_candidate + + # The caller retries with new randomness. + return None diff --git a/src/lean_spec/subspecs/xmss/field.py b/src/lean_spec/subspecs/xmss/field.py new file mode 100644 index 000000000..053b0b8bc --- /dev/null +++ b/src/lean_spec/subspecs/xmss/field.py @@ -0,0 +1,32 @@ +"""Field-element decomposition and secure sampling for the Generalized XMSS scheme.""" + +import secrets + +from ..koalabear import Fp, P +from .constants import XmssConfig +from .types import HashDigestVector, Parameter + + +def int_to_base_p(value: int, num_limbs: int) -> list[Fp]: + """Decompose an integer into a fixed-size list of base-P field elements.""" + acc = value + limbs: list[Fp] = [] + for _ in range(num_limbs): + limbs.append(Fp(value=acc)) + acc //= P + return limbs + + +def random_field_elements(length: int) -> list[Fp]: + """Sample a list of secure-random field elements in [0, P).""" + return [Fp(value=secrets.randbelow(P)) for _ in range(length)] + + +def random_parameter(config: XmssConfig) -> Parameter: + """Sample a fresh public parameter for one XMSS key pair.""" + return Parameter(data=random_field_elements(config.PARAMETER_LEN)) + + +def random_domain(config: XmssConfig) -> HashDigestVector: + """Sample a fresh hash-digest-sized vector of field elements.""" + return HashDigestVector(data=random_field_elements(config.HASH_LEN_FE)) diff --git a/src/lean_spec/subspecs/xmss/interface.py b/src/lean_spec/subspecs/xmss/interface.py index 4a45a152f..be393beba 100644 --- a/src/lean_spec/subspecs/xmss/interface.py +++ b/src/lean_spec/subspecs/xmss/interface.py @@ -1,175 +1,145 @@ -""" -Defines the core interface for the Generalized XMSS signature scheme. - -Specification for the high-level functions (`key_gen`, `sign`, `verify`). - -This constitutes the public API of the signature scheme. -""" - -from __future__ import annotations +"""Public interface for the Generalized XMSS signature scheme.""" from lean_spec.config import LEAN_ENV -from lean_spec.subspecs.xmss.target_sum import ( - PROD_TARGET_SUM_ENCODER, - TEST_TARGET_SUM_ENCODER, - TargetSumEncoder, -) from lean_spec.types import Bytes32, Slot, StrictBaseModel, Uint64 -from .constants import ( - PROD_CONFIG, - TEST_CONFIG, - XmssConfig, -) +from .constants import PROD_CONFIG, TEST_CONFIG, XmssConfig from .containers import KeyPair, PublicKey, SecretKey, Signature -from .prf import PROD_PRF, TEST_PRF, Prf -from .rand import PROD_RAND, TEST_RAND, Rand -from .subtree import HashSubTree, combined_path, verify_path -from .tweak_hash import ( - PROD_TWEAK_HASHER, - TEST_TWEAK_HASHER, - TweakHasher, -) +from .encoding import target_sum_encode +from .field import random_parameter +from .merkle import HashSubTree, combined_path, verify_path +from .poseidon import PROD_POSEIDON, TEST_POSEIDON, PoseidonXmss +from .prf import prf_apply, prf_get_randomness, prf_key_gen from .types import HashDigestList, HashDigestVector -from .utils import expand_activation_time - - -class GeneralizedXmssScheme(StrictBaseModel): - """ - Instance of the Generalized XMSS signature scheme for a given config. - - This class holds the configuration and component instances needed to - perform key generation, signing, and verification operations. - """ - - config: XmssConfig - """Configuration parameters for the XMSS scheme.""" - - prf: Prf - """Pseudorandom function for deriving secret values.""" - hasher: TweakHasher - """Hash function with tweakable domain separation.""" - encoder: TargetSumEncoder - """Message encoder that produces valid codewords.""" +def _expand_activation_time( + log_lifetime: int, desired_activation_slot: int, desired_num_active_slots: int +) -> tuple[int, int]: + """Align a requested activation interval to top-bottom tree boundaries. - rand: Rand - """Random data generator for key generation.""" + Phase 1: round start down to a multiple of sqrt(LIFETIME). + Phase 2: round end up to a multiple of sqrt(LIFETIME). + Phase 3: enforce a minimum duration of two bottom trees. + Phase 4: clamp to the lifetime bound, shifting the interval if needed. - def key_gen(self, activation_slot: Slot, num_active_slots: Uint64) -> KeyPair: - """ - Generates a new cryptographic key pair for a specified range of slots. + Args: + log_lifetime: Base-2 logarithm of the lifetime. + desired_activation_slot: First slot requested. + desired_num_active_slots: Number of slots requested. - This is a **randomized** algorithm that establishes a signer's identity using - the memory-efficient Top-Bottom Tree Traversal approach. - - ### Key Generation Algorithm + Returns: + The pair (start_bottom_tree_index, end_bottom_tree_index). + Actual slots covered are [start * C, end * C) where C = sqrt(LIFETIME). + """ + # C = sqrt(LIFETIME). + # c_mask rounds down to multiples of C. + c = 1 << (log_lifetime // 2) + c_mask = ~(c - 1) + + desired_end_slot = desired_activation_slot + desired_num_active_slots + + # Phase 1 + 2: snap the interval endpoints onto bottom-tree boundaries. + start = desired_activation_slot & c_mask + end = (desired_end_slot + c - 1) & c_mask + + # Phase 3: at least two bottom trees so the prepared window always fits. + if end - start < 2 * c: + end = start + 2 * c + + # Phase 4: clamp to [0, LIFETIME). + lifetime = c * c + if end > lifetime: + duration = end - start + if duration > lifetime: + # The requested interval is wider than the lifetime. + # Use the whole lifetime. + start = 0 + end = lifetime + else: + # Shift the interval back so it ends exactly at the lifetime boundary. + end = lifetime + start = (lifetime - duration) & c_mask - 1. **Expand Activation Time**: Align the requested activation interval to - `sqrt(LIFETIME)` boundaries to enable efficient tree partitioning. - This ensures the interval starts at a multiple of `sqrt(LIFETIME)` and - has a minimum duration of `2 * sqrt(LIFETIME)` slots. + return (start // c, end // c) - 2. **Generate Master Secrets**: Generate PRF key and public parameter `P`. - The PRF key allows deterministic on-demand regeneration of one-time keys. - 3. **Generate First Two Bottom Trees**: Create the first two bottom trees - (covering the initial `2 * sqrt(LIFETIME)` slots) and keep them in memory. - Each bottom tree covers `sqrt(LIFETIME)` consecutive slots. +class GeneralizedXmssScheme(StrictBaseModel): + """Generalized XMSS signature scheme bound to one configuration.""" - 4. **Generate Remaining Bottom Tree Roots**: For all other bottom trees in - the range, generate only their roots (not the full trees). This saves - memory since we only need the first two trees for the prepared window. + config: XmssConfig + """Configuration parameters for this instance.""" - 5. **Build Top Tree**: Construct the top tree from all bottom tree roots. - The top tree's lowest layer contains the bottom tree roots, and it is - built upward to the global Merkle root. + poseidon: PoseidonXmss + """Cached Poseidon1 engine used by every primitive in the scheme.""" - ### Memory Efficiency + def key_gen(self, activation_slot: Slot, num_active_slots: Uint64) -> KeyPair: + """Generate a fresh key pair active for an aligned slot range. - Traditional approach: O(LIFETIME) memory - Top-Bottom approach: O(sqrt(LIFETIME)) memory + Phase 1: align the requested interval to sqrt(LIFETIME) boundaries. + Phase 2: draw the master PRF key and public parameter. + Phase 3: materialize the two leftmost bottom trees. + Phase 4: generate every other bottom tree, retaining only its root. + Phase 5: build the top tree from all bottom-tree roots. - For LOG_LIFETIME=32 (2^32 slots): - - Traditional: ~hundreds of GiB - - Top-Bottom: much more reasonable + The returned key may cover a wider interval than requested because + of boundary alignment and the two-tree minimum window. Args: - activation_slot: The starting slot for which this key is valid. - - Will be aligned downward to `sqrt(LIFETIME)` boundary. - num_active_slots: The number of consecutive slots the key can be used for. - - Will be rounded up to at least `2 * sqrt(LIFETIME)`. + activation_slot: Requested first signable slot. + num_active_slots: Requested number of signable slots. Returns: - A `KeyPair` containing the public and secret keys. + A KeyPair with both halves of the scheme. - Note: - The actual activation slot and num_active_slots in the returned SecretKey - may be larger than requested due to alignment requirements. - - For the formal specification of this process, please refer to: - - "Hash-Based Multi-Signatures for Post-Quantum Ethereum": https://eprint.iacr.org/2025/055 - - "Technical Note: LeanSig for Post-Quantum Ethereum": https://eprint.iacr.org/2025/1332 - - The canonical Rust implementation: https://github.com/b-wagn/hash-sig + Raises: + ValueError: When the requested range exceeds the lifetime. """ - # Retrieve the scheme's configuration parameters. config = self.config - # Ensure the requested activation range is within the scheme's total supported lifetime. + # The requested range must fit within the global lifetime. if int(activation_slot) + int(num_active_slots) > int(config.LIFETIME): raise ValueError("Activation range exceeds the key's lifetime.") - # Generate the random public parameter `P` and the master PRF key. - # - `P` ensures hash function outputs are unique to this key pair. - # - PRF key is the single master secret from which all one-time keys are derived. - parameter = self.rand.parameter() - prf_key = self.prf.key_gen() + # Phase 2: draw the master secret and the public parameter. + parameter = random_parameter(config) + prf_key = prf_key_gen() - # Step 1: Expand and align activation time to sqrt(LIFETIME) boundaries. - start_bottom_tree_index, end_bottom_tree_index = expand_activation_time( + # Phase 1: align onto bottom-tree boundaries. + start_bottom_tree_index, end_bottom_tree_index = _expand_activation_time( config.LOG_LIFETIME, int(activation_slot), int(num_active_slots) ) - - num_bottom_trees = end_bottom_tree_index - start_bottom_tree_index leaves_per_bottom_tree = 1 << (config.LOG_LIFETIME // 2) - - # Calculate the actual (expanded) activation slot and count. actual_activation_slot = start_bottom_tree_index * leaves_per_bottom_tree - actual_num_active_slots = num_bottom_trees * leaves_per_bottom_tree + actual_num_active_slots = ( + end_bottom_tree_index - start_bottom_tree_index + ) * leaves_per_bottom_tree - # Step 2: Generate the first two bottom trees (kept in memory). + # Phase 3: build the two leftmost bottom trees and keep them resident. left_bottom_tree = HashSubTree.from_prf_key( - prf=self.prf, - hasher=self.hasher, - rand=self.rand, + poseidon=self.poseidon, config=config, prf_key=prf_key, bottom_tree_index=Uint64(start_bottom_tree_index), parameter=parameter, ) right_bottom_tree = HashSubTree.from_prf_key( - prf=self.prf, - hasher=self.hasher, - rand=self.rand, + poseidon=self.poseidon, config=config, prf_key=prf_key, bottom_tree_index=Uint64(start_bottom_tree_index + 1), parameter=parameter, ) - # Collect roots for building the top tree. bottom_tree_roots: list[HashDigestVector] = [ left_bottom_tree.root(), right_bottom_tree.root(), ] - # Step 3: Generate remaining bottom trees (only their roots). + # Phase 4: build every other bottom tree and remember only its root. for i in range(start_bottom_tree_index + 2, end_bottom_tree_index): tree = HashSubTree.from_prf_key( - prf=self.prf, - hasher=self.hasher, - rand=self.rand, + poseidon=self.poseidon, config=config, prf_key=prf_key, bottom_tree_index=Uint64(i), @@ -177,21 +147,17 @@ def key_gen(self, activation_slot: Slot, num_active_slots: Uint64) -> KeyPair: ) bottom_tree_roots.append(tree.root()) - # Step 4: Build the top tree from bottom tree roots. + # Phase 5: assemble the top tree from all bottom-tree roots. top_tree = HashSubTree.new_top_tree( - hasher=self.hasher, - rand=self.rand, + poseidon=self.poseidon, + config=config, depth=config.LOG_LIFETIME, start_bottom_tree_index=Uint64(start_bottom_tree_index), parameter=parameter, bottom_tree_roots=bottom_tree_roots, ) - # Extract the global root. - root = top_tree.root() - - # Assemble and return the keys. - pk = PublicKey(root=root, parameter=parameter) + pk = PublicKey(root=top_tree.root(), parameter=parameter) sk = SecretKey( prf_key=prf_key, parameter=parameter, @@ -205,64 +171,39 @@ def key_gen(self, activation_slot: Slot, num_active_slots: Uint64) -> KeyPair: return KeyPair(public_key=pk, secret_key=sk) def sign(self, sk: SecretKey, slot: Slot, message: Bytes32) -> Signature: - """ - Produces a digital signature for a given message at a specific slot. - - This is a **deterministic** algorithm. Calling `sign` twice with the same - (sk, slot, message) triple produces the same signature. + """Produce a signature for a message at a specific slot. - **CRITICAL SECURITY WARNING**: A secret key for a given slot must **NEVER** be used - to sign two different messages. Doing so would reveal parts of the secret key - and allow an attacker to forge signatures. This is the fundamental security - property of a synchronized (stateful) signature scheme. + Phase 1: enforce that the slot is inside the activation and prepared windows. + Phase 2: search for randomness rho whose encoding lands on the target-sum layer. + Phase 3: walk each Winternitz chain to the released hash dictated by the codeword. + Phase 4: build the combined Merkle path through the bottom and top trees. - ### Signing Algorithm - - 1. **Message Encoding with Randomness (`rho`)**: The "Target Sum" scheme - requires the message hash to be encoded into a `codeword` whose digits - sum to a predefined target. A direct hash of the message is unlikely to - satisfy this. Therefore, the algorithm repeatedly hashes the message - combined with deterministic randomness (`rho`) derived from the PRF - until a valid `codeword` is found. - - 2. **One-Time Signature**: The `codeword` dictates how the one-time signature is - formed. For each digit `x_i` in the codeword, the signer reveals an intermediate - hash value by applying the hash function `x_i` times to the secret start of the - `i`-th hash chain. - The collection of these intermediate hashes forms the one-time signature. - - 3. **Merkle Path**: The signer retrieves the Merkle authentication path for the leaf - corresponding to the current `slot`. This path proves that the one-time public key - for this slot is part of the main public key (the Merkle root). + Signing is deterministic in (sk, slot, message). + A secret key must never sign two different messages for the same slot. Args: - sk: The secret key to use for signing. - slot: The slot for which the signature is being created. - message: The message to be signed. + sk: Secret key. + slot: Signing slot. + message: Message to sign. Returns: - The resulting `Signature` object. + A signature carrying the OTS, the Merkle path, and the randomness. - For the formal specification of this process, please refer to: - - "Hash-Based Multi-Signatures for Post-Quantum Ethereum": https://eprint.iacr.org/2025/055 - - "Technical Note: LeanSig for Post-Quantum Ethereum": https://eprint.iacr.org/2025/1332 - - The canonical Rust implementation: https://github.com/b-wagn/hash-sig + Raises: + ValueError: When the slot is outside the activation or prepared window. + RuntimeError: When no valid encoding is found within MAX_TRIES attempts. """ - # Retrieve the scheme's configuration parameters. config = self.config - - # Verify that the secret key is currently active for the requested signing slot. slot_int = int(slot) activation_int = int(sk.activation_slot) + + # Phase 1a: activation bound. if not (activation_int <= slot_int < activation_int + int(sk.num_active_slots)): raise ValueError("Key is not active for the specified slot.") - # Verify that the slot is within the prepared interval (covered by loaded bottom trees). - # - # With top-bottom tree traversal, only slots within the prepared interval can be - # signed without computing additional bottom trees. - # - # If the slot is outside this range, we need to slide the window forward. + # Phase 1b: prepared bound. + # Without two adjacent bottom trees we cannot produce a path without + # paying the cost of regenerating them on the fly. leaves_per_bottom_tree = 1 << (config.LOG_LIFETIME // 2) prepared_start = int(sk.left_bottom_tree_index) * leaves_per_bottom_tree prepared_end = prepared_start + 2 * leaves_per_bottom_tree @@ -273,45 +214,28 @@ def sign(self, sk: SecretKey, slot: Slot, message: Bytes32) -> Signature: f"Call advance_preparation() to slide the window forward." ) - # Find a valid message encoding. - # - # This loop repeatedly tries different randomness `rho` until the encoder - # produces a valid codeword (i.e., one that meets the target sum constraint). - # - # The randomness is deterministically derived from the PRF to ensure - # that signing is reproducible for the same (sk, slot, message). + # Phase 2: deterministic search for valid randomness. + # Randomness comes from the PRF so signing is reproducible. for attempts in range(config.MAX_TRIES): - # Derive deterministic randomness `rho` from PRF using the attempt counter. - rho = self.prf.get_randomness(sk.prf_key, slot, message, Uint64(attempts)) - # Attempt to encode the message with the deterministic `rho`. - codeword = self.encoder.encode(sk.parameter, message, rho, slot) - # If encoding is successful, we've found our `rho` and `codeword`. - # - # We can exit the loop. + rho = prf_get_randomness(config, sk.prf_key, slot, message, Uint64(attempts)) + codeword = target_sum_encode(self.poseidon, config, sk.parameter, message, rho, slot) if codeword is not None: break else: - # This block executes only if the `for` loop completes without a `break`. - # - # This means that no valid encoding was found after the maximum number of tries. raise RuntimeError( f"Failed to find a valid message encoding after {config.MAX_TRIES} tries." ) - # Sanity check to ensure the encoder returned a codeword of the correct length. - if len(codeword) != self.config.DIMENSION: + # Sanity guard against an encoder returning the wrong number of digits. + if len(codeword) != config.DIMENSION: raise RuntimeError("Encoding is broken: returned too many or too few chunks.") - # Compute the one-time signature hashes based on the codeword. + # Phase 3: walk each Winternitz chain to the released hash. ots_hashes: list[HashDigestVector] = [] for chain_index, steps in enumerate(codeword): - # Derive the secret start of the current chain using the master PRF key. - start_digest = self.prf.apply(sk.prf_key, slot, Uint64(chain_index)) - # Walk the hash chain for the number of `steps` specified by the - # corresponding digit in the codeword. - # - # The result is one component of the OTS. - ots_digest = self.hasher.hash_chain( + start_digest = prf_apply(config, sk.prf_key, slot, Uint64(chain_index)) + ots_digest = self.poseidon.hash_chain( + config=config, parameter=sk.parameter, epoch=slot, chain_index=chain_index, @@ -321,106 +245,67 @@ def sign(self, sk: SecretKey, slot: Slot, message: Bytes32) -> Signature: ) ots_hashes.append(ots_digest) - # Retrieve the Merkle authentication path for the current slot's leaf. - # With top-bottom tree traversal, we use combined_path to merge paths from - # the bottom tree and top tree. - - # Determine which bottom tree contains this slot (reuse leaves_per_bottom_tree from above). + # Phase 4: combined Merkle path through both trees. + # The signed slot picks the bottom tree on the prepared window's left or right. boundary = (int(sk.left_bottom_tree_index) + 1) * leaves_per_bottom_tree bottom_tree = sk.left_bottom_tree if slot_int < boundary else sk.right_bottom_tree - - # Generate the combined authentication path path = combined_path(sk.top_tree, bottom_tree, Uint64(int(slot))) - # Assemble and return the final signature, which contains: - # - The OTS, - # - The Merkle path, - # - The randomness `rho` needed for verification. return Signature(path=path, rho=rho, hashes=HashDigestList(data=ots_hashes)) def verify(self, pk: PublicKey, slot: Slot, message: Bytes32, sig: Signature) -> bool: - r""" - Verifies a digital signature against a public key, message, and slot. - - This is a **deterministic** algorithm. - - ### Verification Algorithm - - 1. **Re-encode Message**: The verifier uses the randomness `rho` from the - signature to re-compute the codeword $x = (x_1, \dots, x_v)$ from the message `m`. - If the encoding is invalid (e.g., does not meet the target sum), verification fails. + """Verify a signature against a public key, message, and slot. - 2. **Reconstruct One-Time Public Key**: For each intermediate hash $y_i$ in the - signature's `hashes` field, the verifier completes the corresponding hash chain. - Since $y_i$ was computed by hashing $x_i$ times, the verifier applies the - hash function an additional `BASE - 1 - x_i` times to arrive at the - chain's public endpoint, which is one component of the one-time public key. - - 3. **Compute Merkle Leaf**: The verifier hashes the full set of reconstructed - chain endpoints to compute the expected Merkle leaf for the given slot. - - 4. **Verify Merkle Path**: The verifier uses the authentication `path` from the - signature to compute a candidate Merkle root, starting from the leaf computed - in the previous step. Verification succeeds if and only if this candidate root - matches the `root` stored in the `PublicKey`. + Phase 1: bound-check the slot. + Phase 1 rejects without raising on bad input. + Phase 2: recompute the codeword using the randomness carried by the signature. + Phase 3: complete each Winternitz chain from the released hash to its endpoint. + Phase 4: rebuild the Merkle root from the chain endpoints and the opening. Args: - pk: The public key to verify against. - slot: The slot the signature corresponds to. - message: The message that was supposedly signed. - sig: The signature object to be verified. + pk: Public key. + slot: Signing slot claimed by the signature. + message: Message claimed by the signature. + sig: Signature to verify. Returns: - `True` if the signature is valid, `False` otherwise. - - For the formal specification of this process, please refer to: - - "Hash-Based Multi-Signatures for Post-Quantum Ethereum": https://eprint.iacr.org/2025/055 - - "Technical Note: LeanSig for Post-Quantum Ethereum": https://eprint.iacr.org/2025/1332 - - The canonical Rust implementation: https://github.com/b-wagn/hash-sig + True when the signature is valid against the public key, false otherwise. """ - # Retrieve the scheme's configuration parameters. config = self.config - # Validate slot bounds. - # - # Return False instead of raising to avoid panic on invalid signatures. - # The slot is attacker-controlled input. - if int(slot) >= int(self.config.LIFETIME): + # Phase 1: bound check on the slot. + # The slot is attacker-controlled, so a malformed value returns False + # rather than panicking deep in the verification routine. + if int(slot) >= int(config.LIFETIME): return False - # Re-encode the message using the randomness `rho` from the signature. - # - # If the encoding is invalid (e.g., fails the target sum check), the signature is invalid. - codeword = self.encoder.encode(pk.parameter, message, sig.rho, slot) + # Phase 2: rederive the codeword from the signature's randomness. + # A failing aborting decode means the signature cannot be valid. + codeword = target_sum_encode(self.poseidon, config, pk.parameter, message, sig.rho, slot) if codeword is None: return False - # Reconstruct the one-time public key (the list of chain endpoints). + # Phase 3: finish each chain from the released hash to its endpoint. chain_ends: list[HashDigestVector] = [] for chain_index, xi in enumerate(codeword): - # The signature provides `start_digest`, which is the hash value after `xi` steps. + # The signature provides the digest after xi steps along the chain. + # We hash the remaining BASE - 1 - xi times to reach the endpoint. start_digest = sig.hashes[chain_index] - # We must perform the remaining `BASE - 1 - xi` hashing steps - # to compute the public endpoint of the chain. - num_steps_remaining = config.BASE - 1 - xi - end_digest = self.hasher.hash_chain( + end_digest = self.poseidon.hash_chain( + config=config, parameter=pk.parameter, epoch=slot, chain_index=chain_index, start_step=xi, - num_steps=num_steps_remaining, + num_steps=config.BASE - 1 - xi, start_digest=start_digest, ) chain_ends.append(end_digest) - # Verify the Merkle path. - # - # This function internally: - # - Hashes the `chain_ends` to get the leaf node for the slot, - # - Uses the `opening` path from the signature to compute a candidate root. - # - It returns true if and only if this candidate root matches the public key's root. + # Phase 4: rebuild and compare against the trusted root. return verify_path( - hasher=self.hasher, + poseidon=self.poseidon, + config=config, parameter=pk.parameter, root=pk.root, position=slot, @@ -429,94 +314,58 @@ def verify(self, pk: PublicKey, slot: Slot, message: Bytes32, sig: Signature) -> ) def get_activation_interval(self, sk: SecretKey) -> range: - """ - Returns the slot range for which this secret key is active. - - The activation interval is `[activation_slot, activation_slot + num_active_slots)`. - A signature can only be created for a slot within this range. - - Args: - sk: The secret key to query. + """Return the activation interval as a Python range. - Returns: - A Python range object representing the valid slot range. + A signature is only valid for a slot inside this range. """ start = int(sk.activation_slot) - end = start + int(sk.num_active_slots) - return range(start, end) + return range(start, start + int(sk.num_active_slots)) def get_prepared_interval(self, sk: SecretKey) -> range: - """ - Returns the slot range currently prepared (covered by loaded bottom trees). - - With top-bottom tree traversal, a secret key maintains a sliding window of - two consecutive bottom trees. This method returns the range of slots that - can be signed with the currently loaded trees, without needing to compute - additional bottom trees. - - The prepared interval is: - `[left_bottom_tree_index * sqrt(LIFETIME), (left_bottom_tree_index + 2) * sqrt(LIFETIME))` - - Args: - sk: The secret key to query. - - Returns: - A Python range object representing the prepared slot range. + """Return the prepared interval as a Python range. - Raises: - ValueError: If the secret key is missing top-bottom tree structures. + The prepared interval is the slot window covered by the two resident bottom trees. + A signer can sign any slot in this range without paying the cost of + rebuilding a bottom tree from the PRF. """ leaves_per_bottom_tree = 1 << (self.config.LOG_LIFETIME // 2) start = int(sk.left_bottom_tree_index) * leaves_per_bottom_tree return range(start, start + 2 * leaves_per_bottom_tree) def advance_preparation(self, sk: SecretKey) -> SecretKey: - """ - Advances the prepared interval by computing the next bottom tree. - - This method implements the "sliding window" strategy for top-bottom tree - traversal. It: - 1. Computes a new bottom tree for the next interval - 2. Shifts the current right tree to become the new left tree - 3. The newly computed tree becomes the new right tree - 4. Increments `left_bottom_tree_index` + """Slide the prepared window one bottom tree forward. - After this operation, the prepared interval moves forward by `sqrt(LIFETIME)` slots. + Phase 1: bail out when the next window would exceed the activation interval. + Phase 2: regenerate the new right bottom tree from the PRF key. + Phase 3: shift the previous right tree to the left slot. - **When to call**: Call this method after signing with a slot that is in the - right half of the prepared interval, to ensure the next slot range is ready. + Returning the same key when no advancement is possible keeps callers simple. Args: - sk: The secret key to advance. + sk: Secret key whose prepared window should advance. Returns: - A new SecretKey with the advanced preparation window. - - Raises: - ValueError: If advancing would exceed the activation interval. + A secret key with the window shifted by one bottom tree. """ leaves_per_bottom_tree = 1 << (self.config.LOG_LIFETIME // 2) left_index = int(sk.left_bottom_tree_index) - # Check if advancing would exceed the activation interval + # Phase 1: no advancement once the activation interval is fully consumed. next_prepared_end_slot = (left_index + 3) * leaves_per_bottom_tree activation_end = int(sk.activation_slot) + int(sk.num_active_slots) if next_prepared_end_slot > activation_end: - # Nothing to do - we're already at the end of the activation interval return sk - # Compute the next bottom tree (the one after the current right tree) + # Phase 2: rebuild the next bottom tree from the master PRF key. new_right_bottom_tree = HashSubTree.from_prf_key( - prf=self.prf, - hasher=self.hasher, - rand=self.rand, + poseidon=self.poseidon, config=self.config, prf_key=sk.prf_key, bottom_tree_index=Uint64(left_index + 2), parameter=sk.parameter, ) - # Return a new SecretKey with the advanced window + # Phase 3: rotate the right tree into the left slot, advance the index. return sk.model_copy( update={ "left_bottom_tree": sk.right_bottom_tree, @@ -526,28 +375,11 @@ def advance_preparation(self, sk: SecretKey) -> SecretKey: ) -PROD_SIGNATURE_SCHEME = GeneralizedXmssScheme( - config=PROD_CONFIG, - prf=PROD_PRF, - hasher=PROD_TWEAK_HASHER, - encoder=PROD_TARGET_SUM_ENCODER, - rand=PROD_RAND, -) -"""An instance configured for production-level parameters.""" - -TEST_SIGNATURE_SCHEME = GeneralizedXmssScheme( - config=TEST_CONFIG, - prf=TEST_PRF, - hasher=TEST_TWEAK_HASHER, - encoder=TEST_TARGET_SUM_ENCODER, - rand=TEST_RAND, -) -"""A lightweight instance for test environments.""" - -_LEAN_ENV_TO_SCHEME = { - "test": TEST_SIGNATURE_SCHEME, - "prod": PROD_SIGNATURE_SCHEME, -} - -TARGET_SIGNATURE_SCHEME = _LEAN_ENV_TO_SCHEME[LEAN_ENV] -"""The active XMSS signature scheme based on LEAN_ENV environment variable.""" +PROD_SIGNATURE_SCHEME = GeneralizedXmssScheme(config=PROD_CONFIG, poseidon=PROD_POSEIDON) +"""Signature scheme instance with production parameters.""" + +TEST_SIGNATURE_SCHEME = GeneralizedXmssScheme(config=TEST_CONFIG, poseidon=TEST_POSEIDON) +"""Signature scheme instance with test parameters.""" + +TARGET_SIGNATURE_SCHEME = TEST_SIGNATURE_SCHEME if LEAN_ENV == "test" else PROD_SIGNATURE_SCHEME +"""Active scheme selected at import time from the LEAN_ENV environment variable.""" diff --git a/src/lean_spec/subspecs/xmss/merkle.py b/src/lean_spec/subspecs/xmss/merkle.py new file mode 100644 index 000000000..eb924e2cf --- /dev/null +++ b/src/lean_spec/subspecs/xmss/merkle.py @@ -0,0 +1,480 @@ +"""Sparse Merkle subtrees for top-bottom XMSS traversal. + +The XMSS lifetime tree is split into one top tree and many bottom trees. +Each bottom tree covers sqrt(LIFETIME) consecutive slots. +The signer keeps the full top tree plus two adjacent bottom trees resident, +forming a sliding window of 2*sqrt(LIFETIME) signable slots. + +This bounds the secret-key memory at O(sqrt(LIFETIME)) instead of O(LIFETIME). +""" + +from itertools import batched +from typing import Self + +from lean_spec.types import Uint64 +from lean_spec.types.container import Container + +from .constants import XmssConfig +from .field import random_domain +from .poseidon import PoseidonXmss +from .prf import prf_apply +from .types import ( + HashDigestList, + HashDigestVector, + HashTreeLayer, + HashTreeLayers, + HashTreeOpening, + Parameter, + PRFKey, + TreeTweak, +) + + +def _padded_layer( + config: XmssConfig, + nodes: list[HashDigestVector], + start_index: Uint64, +) -> HashTreeLayer: + """Pad a layer so every node has a sibling and parent generation has no edge cases. + + Invariant: the padded layer starts at an even index and ends at an odd index. + A single-node layer is allowed when the layer is the root. + """ + nodes_with_padding: list[HashDigestVector] = [] + end_index = start_index + Uint64(len(nodes)) - Uint64(1) + + # Prepend one random sibling when the layer begins on an odd index. + if start_index % Uint64(2) == Uint64(1): + nodes_with_padding.append(random_domain(config)) + + # The padded layer always starts on the even index at or before start_index. + actual_start_index = start_index - (start_index % Uint64(2)) + + nodes_with_padding.extend(nodes) + + # Append one random sibling when the layer ends on an even index. + if end_index % Uint64(2) == Uint64(0): + nodes_with_padding.append(random_domain(config)) + + return HashTreeLayer( + start_index=actual_start_index, + nodes=HashDigestList(data=nodes_with_padding), + ) + + +class HashSubTree(Container): + """Sparse Merkle subtree of an XMSS lifetime tree. + + Stores layers from lowest_layer up to the subtree root. + A bottom tree has lowest_layer = 0 and covers a window of leaves. + A top tree has lowest_layer = LOG_LIFETIME/2 and covers the bottom-tree roots. + + Layout invariant: every active layer starts on an even index and ends on + an odd index except for the single-node root layer. + """ + + depth: Uint64 + """Depth of the full lifetime tree this subtree belongs to. + A subtree starting at layer k stores depth - k layers.""" + + lowest_layer: Uint64 + """Lowest layer included in this subtree. + Zero for bottom trees, LOG_LIFETIME/2 for top trees.""" + + layers: HashTreeLayers + """Layers stored from lowest_layer up to the subtree root. + The last entry holds a single node, the subtree root.""" + + @classmethod + def new( + cls, + poseidon: PoseidonXmss, + config: XmssConfig, + lowest_layer: Uint64, + depth: Uint64, + start_index: Uint64, + parameter: Parameter, + lowest_layer_nodes: list[HashDigestVector], + ) -> Self: + """Build a subtree from its lowest layer up to the root. + + Phase 1: pad the input layer to the alignment invariant. + Phase 2: hash each sibling pair to produce the next layer up. + Phase 3: pad each new layer and continue to the root. + + Args: + poseidon: Cached Poseidon1 engine. + config: Active XMSS configuration. + lowest_layer: Starting layer for this subtree. + depth: Total depth of the full lifetime tree. + start_index: Absolute index of the first input node. + parameter: Public parameter for the hash function. + lowest_layer_nodes: Active nodes at the lowest layer. + + Returns: + A subtree containing every layer from lowest_layer to the root. + """ + # The input nodes must fit in the layer they belong to. + max_positions = 1 << int(depth - lowest_layer) + if int(start_index) + len(lowest_layer_nodes) > max_positions: + raise ValueError( + f"Overflow at layer {lowest_layer}: " + f"start={start_index}, count={len(lowest_layer_nodes)}, max={max_positions}" + ) + + # Phase 1: pad the input layer. + layers: list[HashTreeLayer] = [] + current = _padded_layer(config, lowest_layer_nodes, start_index) + layers.append(current) + + # Phases 2 + 3: hash sibling pairs, pad, repeat. + for level in range(lowest_layer, depth): + parent_start = current.start_index // Uint64(2) + parents = [ + poseidon.tweak_hash( + config, + parameter, + TreeTweak(level=level + 1, index=parent_start + Uint64(i)), + [left, right], + ) + for i, (left, right) in enumerate(batched(current.nodes, 2)) + ] + current = _padded_layer(config, parents, parent_start) + layers.append(current) + + return cls( + depth=depth, + lowest_layer=lowest_layer, + layers=HashTreeLayers(data=layers), + ) + + @classmethod + def new_top_tree( + cls, + poseidon: PoseidonXmss, + config: XmssConfig, + depth: int, + start_bottom_tree_index: Uint64, + parameter: Parameter, + bottom_tree_roots: list[HashDigestVector], + ) -> Self: + """Build the top tree from bottom-tree roots up to the global root. + + The top tree starts at layer depth/2 and treats bottom-tree roots as its leaves. + + Args: + poseidon: Cached Poseidon1 engine. + config: Active XMSS configuration. + depth: Total depth of the full lifetime tree. + start_bottom_tree_index: Index of the first bottom tree in the range. + parameter: Public parameter for the hash function. + bottom_tree_roots: Roots of all bottom trees in the range, in order. + + Returns: + A top tree whose root is the global Merkle root. + + Raises: + ValueError: When depth is odd. + """ + # Top-bottom split requires an even depth. + if depth % 2 != 0: + raise ValueError(f"Depth must be even for top-bottom split, got {depth}.") + + return cls.new( + poseidon=poseidon, + config=config, + lowest_layer=Uint64(depth // 2), + depth=Uint64(depth), + start_index=start_bottom_tree_index, + parameter=parameter, + lowest_layer_nodes=bottom_tree_roots, + ) + + @classmethod + def new_bottom_tree( + cls, + poseidon: PoseidonXmss, + config: XmssConfig, + depth: int, + bottom_tree_index: Uint64, + parameter: Parameter, + leaves: list[HashDigestVector], + ) -> Self: + """Build one bottom tree from leaf hashes up to its standalone root. + + Phase 1: build a full subtree from layer 0 using the provided leaves. + Phase 2: drop the layers above depth/2 produced by extra padding. + Phase 3: replace the highest layer with a single-node root extracted from middle. + + Args: + poseidon: Cached Poseidon1 engine. + config: Active XMSS configuration. + depth: Total depth of the full lifetime tree. + bottom_tree_index: Index of this bottom tree. + parameter: Public parameter for the hash function. + leaves: Pre-hashed one-time public keys for this bottom tree's slots. + + Returns: + A subtree with layers 0 through depth/2 ending in the bottom-tree root. + + Raises: + ValueError: When depth is odd or the leaf count does not match sqrt(LIFETIME). + """ + if depth % 2 != 0: + raise ValueError(f"Depth must be even for top-bottom split, got {depth}.") + + # Each bottom tree spans exactly sqrt(LIFETIME) leaves. + leaves_per_tree = 1 << (depth // 2) + if len(leaves) != leaves_per_tree: + raise ValueError( + f"Expected {leaves_per_tree} leaves for depth={depth}, got {len(leaves)}." + ) + + # Phase 1: build a full subtree from layer 0. + full_tree = cls.new( + poseidon=poseidon, + config=config, + lowest_layer=Uint64(0), + depth=Uint64(depth), + start_index=bottom_tree_index * Uint64(leaves_per_tree), + parameter=parameter, + lowest_layer_nodes=leaves, + ) + + # Phase 3: extract the middle layer's root entry for this bottom tree. + middle = full_tree.layers[depth // 2] + root_idx = int(bottom_tree_index - middle.start_index) + root_layer = HashTreeLayer( + start_index=bottom_tree_index, + nodes=HashDigestList(data=[middle.nodes[root_idx]]), + ) + + # Phase 2 + 3: keep layers 0 through depth/2 - 1, then append the standalone root. + truncated = list(full_tree.layers[: depth // 2]) + return cls( + depth=Uint64(depth), + lowest_layer=Uint64(0), + layers=HashTreeLayers(data=truncated + [root_layer]), + ) + + @classmethod + def from_prf_key( + cls, + poseidon: PoseidonXmss, + config: XmssConfig, + prf_key: PRFKey, + bottom_tree_index: Uint64, + parameter: Parameter, + ) -> Self: + """Regenerate one bottom tree on demand from the master PRF key. + + Phase 1: for every epoch in the bottom tree, derive chain starts via PRF. + Phase 2: hash each chain for BASE - 1 steps to obtain the chain endpoints. + Phase 3: hash chain endpoints into a leaf, then build the bottom tree. + + Args: + poseidon: Cached Poseidon1 engine. + config: Active XMSS configuration. + prf_key: Master secret seed. + bottom_tree_index: Index of the bottom tree to regenerate. + parameter: Public parameter for the hash function. + + Returns: + The requested bottom tree. + """ + # Each bottom tree covers sqrt(LIFETIME) consecutive epochs. + leaves_per_bottom_tree = 1 << (config.LOG_LIFETIME // 2) + start_epoch = bottom_tree_index * Uint64(leaves_per_bottom_tree) + end_epoch = start_epoch + Uint64(leaves_per_bottom_tree) + + leaf_hashes: list[HashDigestVector] = [] + for epoch in range(start_epoch, end_epoch): + # Phases 1 + 2: derive each chain start, then walk it to the public endpoint. + chain_ends: list[HashDigestVector] = [] + for chain_index in range(config.DIMENSION): + start_digest = prf_apply(config, prf_key, Uint64(epoch), Uint64(chain_index)) + end_digest = poseidon.hash_chain( + config=config, + parameter=parameter, + epoch=Uint64(epoch), + chain_index=chain_index, + start_step=0, + num_steps=config.BASE - 1, + start_digest=start_digest, + ) + chain_ends.append(end_digest) + + # Phase 3: hash all chain endpoints into the leaf for this epoch. + leaf_tweak = TreeTweak(level=0, index=Uint64(epoch)) + leaf_hash = poseidon.tweak_hash(config, parameter, leaf_tweak, chain_ends) + leaf_hashes.append(leaf_hash) + + return cls.new_bottom_tree( + poseidon=poseidon, + config=config, + depth=config.LOG_LIFETIME, + bottom_tree_index=bottom_tree_index, + parameter=parameter, + leaves=leaf_hashes, + ) + + def root(self) -> HashDigestVector: + """Return the single node in the highest stored layer. + + Raises: + ValueError: When the subtree is empty or the highest layer has no nodes. + """ + if not self.layers: + raise ValueError("Empty subtree has no root.") + if not self.layers[-1].nodes: + raise ValueError("Top layer is empty.") + return self.layers[-1].nodes[0] + + def path(self, position: Uint64) -> HashTreeOpening: + """Build the authentication path from a leaf up to the subtree root. + + For a subtree covering layers L through H, the opening contains H - L siblings, + one per layer between L and H - 1. + + Args: + position: Absolute index of the leaf in the full tree coordinate system. + + Returns: + An opening of sibling hashes from bottom to top. + + Raises: + ValueError: When the subtree is empty or the position is out of bounds. + """ + if not self.layers: + raise ValueError("Empty subtree.") + + first = self.layers[0] + if not (first.start_index <= position < first.start_index + Uint64(len(first.nodes))): + raise ValueError(f"Position {position} out of bounds.") + + siblings: list[HashDigestVector] = [] + pos = position + + # Stop one short of the root layer. + # The root has no sibling. + for layer in self.layers[:-1]: + # The sibling sits at the position with the last bit flipped, then we + # rebase by the layer's start_index because the layer is sparse. + sibling_idx = int((pos ^ Uint64(1)) - layer.start_index) + if not (0 <= sibling_idx < len(layer.nodes)): + raise ValueError(f"Sibling index {sibling_idx} out of bounds.") + siblings.append(layer.nodes[sibling_idx]) + pos = pos // Uint64(2) + + return HashTreeOpening(siblings=HashDigestList(data=siblings)) + + +def combined_path( + top_tree: HashSubTree, + bottom_tree: HashSubTree, + position: Uint64, +) -> HashTreeOpening: + """Concatenate the bottom-tree and top-tree openings for one leaf. + + A signature must authenticate the leaf against the global root. + The bottom opening proves leaf membership in its bottom tree. + The top opening proves the bottom-tree root sits under the global root. + + Args: + top_tree: The top tree containing the global root. + bottom_tree: The bottom tree containing the leaf. + position: Absolute index of the leaf. + + Returns: + An opening with depth siblings authenticating the leaf against the global root. + + Raises: + ValueError: When tree depths mismatch, depth is odd, or position is out + of bounds for the supplied bottom tree. + """ + if top_tree.depth != bottom_tree.depth: + raise ValueError(f"Depth mismatch: top={top_tree.depth}, bottom={bottom_tree.depth}.") + + depth = int(top_tree.depth) + if depth % 2 != 0: + raise ValueError(f"Depth must be even, got {depth}.") + + # The position must belong to the supplied bottom tree, not a sibling one. + leaves_per_tree = Uint64(1 << (depth // 2)) + expected_start = (position // leaves_per_tree) * leaves_per_tree + if bottom_tree.layers[0].start_index != expected_start: + raise ValueError( + f"Wrong bottom tree: position {position} needs start {expected_start}, " + f"got {bottom_tree.layers[0].start_index}." + ) + + # Bottom path proves leaf -> bottom-tree root. + # Top path proves bottom root -> global root. + bottom_path = bottom_tree.path(position) + top_path = top_tree.path(position // leaves_per_tree) + combined = tuple(bottom_path.siblings.data) + tuple(top_path.siblings.data) + + return HashTreeOpening(siblings=HashDigestList(data=combined)) + + +def verify_path( + poseidon: PoseidonXmss, + config: XmssConfig, + parameter: Parameter, + root: HashDigestVector, + position: Uint64, + leaf_parts: list[HashDigestVector], + opening: HashTreeOpening, +) -> bool: + """Verify a Merkle opening against a trusted root. + + Phase 1: hash leaf_parts into the leaf digest. + Phase 2: walk the opening, hashing the current node with each sibling. + Phase 3: compare the reconstructed root with the trusted one. + + Returns False on attacker-controlled invalid input instead of raising. + + Args: + poseidon: Cached Poseidon1 engine. + config: Active XMSS configuration. + parameter: Public parameter for the hash function. + root: Trusted root taken from the public key. + position: Absolute index of the leaf being verified. + leaf_parts: Digests that constitute the original leaf. + opening: Sibling path from leaf to root. + + Returns: + True when the path reconstructs the root, False otherwise. + """ + # Guard against malformed openings. + # The opening list caps at 32 entries. + # A depth greater than 32 would overflow the position bound check below. + depth = len(opening.siblings) + if depth > 32: + return False + if int(position) >= (1 << depth): + return False + + # Phase 1: hash the leaf parts to derive the starting node. + current = poseidon.tweak_hash( + config, + parameter, + TreeTweak(level=0, index=Uint64(position)), + leaf_parts, + ) + pos = int(position) + + # Phase 2: hash with each sibling, climbing one layer per iteration. + for level, sibling in enumerate(opening.siblings): + # The current node sits on the left when its position is even. + left, right = (current, sibling) if pos % 2 == 0 else (sibling, current) + pos //= 2 + current = poseidon.tweak_hash( + config, + parameter, + TreeTweak(level=level + 1, index=Uint64(pos)), + [left, right], + ) + + # Phase 3: compare against the trusted root. + return current == root diff --git a/src/lean_spec/subspecs/xmss/message_hash.py b/src/lean_spec/subspecs/xmss/message_hash.py deleted file mode 100644 index f1d253750..000000000 --- a/src/lean_spec/subspecs/xmss/message_hash.py +++ /dev/null @@ -1,161 +0,0 @@ -""" -Defines the message hashing for the signature scheme using aborting hypercube encoding. - -### The Challenge: Efficiently Encoding a Message as a Codeword - -The "Target Sum" signature scheme requires the signer to find a `codeword` whose -digits sum to a specific value. This requires hashing a message and mapping the -output to a vertex in a high-dimensional hypercube. - -### The Solution: Aborting Hypercube Encoding - -This module implements a circuit-friendly encoding based on rejection sampling of -individual field elements, eliminating all big-integer arithmetic. - -For KoalaBear (`P = 2^31 - 2^24 + 1`), `P - 1 = Q * BASE^Z`, so each field element -can be decomposed into `Z` base-`BASE` digits after dividing by `Q`. The only reject -case is `A_i == P - 1` (probability ~4.7e-10 per FE — essentially never aborts). - -This is backed by the "Aborting Random Oracles" paper which proves -indifferentiability from a theta-aborting random oracle when modeling Poseidon as a -standard random oracle. - -The encoding proceeds in two stages: - -1. **Input Preparation**: All inputs are encoded into field elements. -2. **Poseidon Hashing + Aborting Decode**: Poseidon1 produces `ceil(DIMENSION/Z)` - field elements, each decoded into `Z` base-`BASE` digits via rejection sampling. -""" - -from __future__ import annotations - -from lean_spec.subspecs.xmss.poseidon import ( - PROD_POSEIDON, - TEST_POSEIDON, - PoseidonXmss, -) -from lean_spec.types import Bytes32, StrictBaseModel, Uint64 - -from ..koalabear import Fp -from .constants import ( - PROD_CONFIG, - TEST_CONFIG, - TWEAK_PREFIX_MESSAGE, - XmssConfig, -) -from .types import Parameter, Randomness -from .utils import int_to_base_p - - -class MessageHasher(StrictBaseModel): - """An instance of the message hasher using aborting hypercube encoding.""" - - config: XmssConfig - """Configuration parameters for the hasher.""" - - poseidon: PoseidonXmss - """Poseidon hash engine.""" - - def encode_message(self, message: Bytes32) -> list[Fp]: - """ - Encodes a 32-byte message into a list of field elements. - - The message bytes are interpreted as a single little-endian integer, - which is then decomposed into its base-`P` representation, where `P` - is the field prime. This provides a canonical mapping from bytes to - the algebraic structure required by Poseidon1. - """ - # Interpret the 32 little-endian bytes as a single large integer. - acc = int.from_bytes(message, "little") - - # Decompose the integer into a list of field elements (base-P). - return int_to_base_p(acc, self.config.MSG_LEN_FE) - - def encode_epoch(self, epoch: Uint64) -> list[Fp]: - """ - Encodes the epoch and a domain separator prefix into field elements. - - This function packs the epoch and the message hash prefix into a single - integer, then decomposes it. This ensures the epoch is included in the - hash input in a structured, domain-separated way. - """ - # Combine the epoch and the message hash prefix into a single integer. - acc = (int(epoch) << 8) | TWEAK_PREFIX_MESSAGE - - # Decompose the integer into its base-P representation. - return int_to_base_p(acc, self.config.TWEAK_LEN_FE) - - def _aborting_decode(self, field_elements: list[Fp]) -> list[int] | None: - """ - Decodes Poseidon output field elements into base-`BASE` digits via rejection sampling. - - For each field element `A_i`: - - 1. If `A_i >= Q * BASE^Z` (i.e. `A_i == P - 1`), abort and return `None`. - 2. Compute `d_i = A_i // Q`, an integer in `[0, BASE^Z - 1]`. - 3. Decompose `d_i` into `Z` base-`BASE` digits, least significant first. - - Collect all digits and return the first `DIMENSION` of them. - """ - config = self.config - threshold = config.Q * config.BASE**config.Z - - digits: list[int] = [] - for fe in field_elements: - a = int(fe) - - # Rejection: the only failing case is A_i == P - 1. - if a >= threshold: - return None - - # Integer quotient removes the Q-residue, leaving a uniform value in [0, BASE^Z - 1]. - d = a // config.Q - - # Decompose d into Z base-BASE digits, least significant first. - for _ in range(config.Z): - d, digit = divmod(d, config.BASE) - digits.append(digit) - - # Take exactly DIMENSION digits. - return digits[: config.DIMENSION] - - def apply( - self, - parameter: Parameter, - epoch: Uint64, - rho: Randomness, - message: Bytes32, - ) -> list[int] | None: - """ - Applies message hashing followed by aborting hypercube decode. - - Hashes the inputs with Poseidon1 to produce `MH_HASH_LEN_FE` field elements, - then decodes them into a candidate codeword via rejection sampling. - - Args: - parameter: The public parameter `P`. - epoch: The current epoch. - rho: A random value `rho` to ensure a unique hash output. - message: The 32-byte message to be hashed. - - Returns: - A candidate codeword (list of `DIMENSION` digits in `[0, BASE-1]`), - or `None` if the aborting decode rejects. - """ - # Encode the message and epoch as field elements. - message_fe = self.encode_message(message) - epoch_fe = self.encode_epoch(epoch) - - # Call Poseidon1 once to produce the required number of output field elements. - base_input = message_fe + list(parameter.data) + epoch_fe + list(rho.data) - poseidon_output = self.poseidon.compress(base_input, 24, self.config.MH_HASH_LEN_FE) - - # Decode the field elements into base-BASE digits via rejection sampling. - return self._aborting_decode(poseidon_output) - - -PROD_MESSAGE_HASHER = MessageHasher(config=PROD_CONFIG, poseidon=PROD_POSEIDON) -"""An instance configured for production-level parameters.""" - -TEST_MESSAGE_HASHER = MessageHasher(config=TEST_CONFIG, poseidon=TEST_POSEIDON) -"""A lightweight instance for test environments.""" diff --git a/src/lean_spec/subspecs/xmss/poseidon.py b/src/lean_spec/subspecs/xmss/poseidon.py index 37a6ff7b1..c340a9a2c 100644 --- a/src/lean_spec/subspecs/xmss/poseidon.py +++ b/src/lean_spec/subspecs/xmss/poseidon.py @@ -1,30 +1,15 @@ -""" -Defines the Poseidon1 hash functions for the Generalized XMSS scheme. - -### The Cryptographic Engine: Why Poseidon1? - -This module provides the low-level cryptographic engine for all internal hashing -operations. It is built on **Poseidon1** hash function. - -The choice of Poseidon1 is deliberate and critical for the scheme's ultimate goal. -Unlike traditional hashes like SHA-3, Poseidon1 is an **arithmetization-friendly** -(or **SNARK-friendly**) hash function. Its algebraic structure is simple, making it -exponentially faster to prove and verify inside a zero-knowledge proof system, -which is essential for aggregating many signatures into a single, compact proof. - -This file provides wrappers for the two primary ways Poseidon1 is used: +"""Poseidon1 hash engine in compression and sponge modes for the Generalized XMSS scheme. -1. **Compression Mode**: A fast, fixed-input-size mode for hashing small, - predictable data structures like a single hash digest or a pair of them. -2. **Sponge Mode**: A flexible, variable-input-size mode for hashing large - amounts of data, like the many digests that form a Merkle tree leaf. +Poseidon1 is arithmetization-friendly. +Hashing across hash chains, Merkle nodes, and Merkle leaves uses a single +permutation, which keeps the in-SNARK aggregation step cheap. """ -from __future__ import annotations +from itertools import batched from pydantic import PrivateAttr -from lean_spec.types import StrictBaseModel +from lean_spec.types import StrictBaseModel, Uint64 from ..koalabear import Fp from ..poseidon1.permutation import ( @@ -33,107 +18,93 @@ Poseidon1, Poseidon1Params, ) -from .utils import int_to_base_p +from .constants import TWEAK_PREFIX_CHAIN, TWEAK_PREFIX_TREE, XmssConfig +from .field import int_to_base_p +from .types import ChainTweak, HashDigestVector, Parameter, TreeTweak class PoseidonXmss(StrictBaseModel): - """An instance of the Poseidon1 hash engine for the XMSS scheme.""" + """Poseidon1 hash engine wrapper used inside the XMSS scheme.""" params16: Poseidon1Params - """Poseidon1 parameters for 16-width permutation.""" + """Permutation parameters for the width-16 state.""" params24: Poseidon1Params - """Poseidon1 parameters for 24-width permutation.""" + """Permutation parameters for the width-24 state.""" - _engine16: Poseidon1 | None = PrivateAttr(default=None) - _engine24: Poseidon1 | None = PrivateAttr(default=None) + _engines: dict[int, Poseidon1] = PrivateAttr(default_factory=dict) def _get_engine(self, width: int) -> Poseidon1: - """Return a cached Poseidon1 engine for the given width.""" - if width == 16: - if self._engine16 is None: - self._engine16 = Poseidon1(self.params16) - return self._engine16 - if self._engine24 is None: - self._engine24 = Poseidon1(self.params24) - return self._engine24 + """Return a cached Poseidon1 engine for the given width. - def compress(self, input_vec: list[Fp], width: int, output_len: int) -> list[Fp]: + Raises: + ValueError: When the width is neither 16 nor 24. """ - Implements the Poseidon1 hash in **compression mode**. - - This mode is used for hashing fixed-size inputs and is the most efficient - way to use Poseidon1. It is used for traversing hash chains and building - the internal nodes of the Merkle tree. + if width not in self._engines: + match width: + case 16: + params = self.params16 + case 24: + params = self.params24 + case _: + raise ValueError(f"Width must be 16 or 24, got {width}") + self._engines[width] = Poseidon1(params) + return self._engines[width] - ### Compression Algorithm + def compress(self, input_vec: list[Fp], width: int, output_len: int) -> list[Fp]: + """Poseidon1 in compression mode. - The function computes: `Truncate(Permute(padded_input) + padded_input)`. - 1. **Padding**: The `input_vec` is padded with zeros to match the full state `width`. - 2. **Permutation**: The core cryptographic permutation is applied to the padded state. - 3. **Feed-Forward**: The original padded input is added element-wise to the - permuted state. This is a key feature of the Poseidon1 design that - provides security against certain attacks. - 4. **Truncation**: The result is truncated to the desired `output_len`. + Computes Truncate(Permute(padded_input) + padded_input). + The padded input is the original vector zero-extended to the state width. + The feed-forward addition is part of the Poseidon1 design and is required + for security. + Used for hash chains and Merkle interior nodes. Args: - input_vec: The list of field elements to be hashed. - width: The state width of the Poseidon1 permutation (16 or 24). - output_len: The number of field elements in the output digest. + input_vec: Field elements to hash. + width: Permutation state width, either 16 or 24. + output_len: Number of output field elements to return. Returns: - A hash digest of `output_len` field elements. + Truncated digest of output_len field elements. """ - # Check that the input vector is long enough to produce the output. + # The output cannot be longer than the input vector after padding. if len(input_vec) < output_len: raise ValueError("Input vector is too short for requested output length.") - # Select the correct permutation parameters based on the state width. - if width not in (16, 24): - raise ValueError(f"Width must be 16 or 24, got {width}") + # Select the cached engine matching the requested permutation width. engine = self._get_engine(width) - # Create a padded input by extending with zeros to match the state width. + # Zero-pad to the state width before applying the permutation. padded_input = list(input_vec) + [Fp(value=0)] * (width - len(input_vec)) - # Apply the Poseidon1 permutation. + # Permute, then add the original padded input element-wise. permuted_state = engine.permute(padded_input) - - # Apply the feed-forward step, adding the input back element-wise. final_state = [p + i for p, i in zip(permuted_state, padded_input, strict=True)] - # Truncate the state to the desired output length and return. return final_state[:output_len] def safe_domain_separator(self, lengths: list[int], capacity_len: int) -> list[Fp]: - """ - Computes a unique domain separator for the sponge construction (SAFE API). + """Build a capacity initialization vector for the sponge construction. - A sponge's security relies on its initial state being unique for each distinct - hashing task. This function creates a unique "configuration" or - "initialization vector" (`capacity_value`) by hashing the high-level - parameters of the sponge's usage (e.g., the dimensions of the data - being hashed). This prevents multi-user or cross-context attacks. + Hashes the packed length parameters into a fixed-size capacity value. + This prevents collisions between sponges that absorb data of different shapes. Args: - lengths: A list of integer parameters that define the hash context. - capacity_len: The desired length of the output capacity value. + lengths: Integer parameters that define the hash context. + capacity_len: Number of field elements in the returned capacity value. Returns: - A list of `capacity_len` field elements for initializing the sponge. + A capacity vector of length capacity_len. """ - # Pack all the length parameters into a single, large, unambiguous integer. + # Pack all lengths into a single unambiguous integer using 32-bit slots. acc = 0 for length in lengths: acc = (acc << 32) | length - # Decompose this integer into a fixed-size list of field elements. - # - # This list serves as the input to a one-off compression hash. - # NOTE: we always use this mode with a 24 width. + # Compress the decomposed vector through the width-24 engine. + # Width 24 is the only mode used for sponge domain separation. input_vec = int_to_base_p(acc, 24) - - # Compress the decomposed vector to produce the capacity value. return self.compress(input_vec, 24, capacity_len) def sponge( @@ -143,80 +114,158 @@ def sponge( output_len: int, width: int, ) -> list[Fp]: - """ - Implements the Poseidon1 hash using the **sponge construction**. - - This mode is used for hashing large or variable-length inputs. In this scheme, - it is specifically used to hash the Merkle tree leaves, which consist of many - concatenated hash digests. - - ### Sponge Algorithm - - 1. **Initialization**: The internal state is divided into a `rate` (for data) - and a `capacity` (for security). The `capacity` part is initialized - with the domain-separating `capacity_value`. - - 2. **Absorbing**: The input data is processed in `rate`-sized chunks. In each - step, a chunk is added to the `rate` part of the state, and then the - entire state is scrambled by the `permute` function. + """Poseidon1 in sponge mode. - 3. **Squeezing**: Once all input is absorbed, the `rate` part of the state is - extracted as output. If more output is needed, the state is permuted again, - and more is extracted, repeating until `output_len` elements are generated. + Phase 1: load capacity, zero-extend input to a multiple of the rate. + Phase 2: absorb each rate-sized chunk by replacement, then permute. + Phase 3: squeeze the rate slots until output_len elements are produced. Args: - input_vec: The input data of arbitrary length. - capacity_value: The domain-separating value from `safe_domain_separator`. - output_len: The number of field elements in the final output digest. - width: The width of the Poseidon1 permutation. + input_vec: Variable-length input. + capacity_value: Domain-separating capacity initialization. + output_len: Desired output length in field elements. + width: Permutation state width. Returns: - A hash digest of `output_len` field elements. + A digest of output_len field elements. """ - # Ensure that the capacity value is not too long. + # The capacity must leave at least one rate slot for absorbing input. if len(capacity_value) >= width: raise ValueError("Capacity length must be smaller than the state width.") - # Determine the permutation parameters and the size of the rate. - if width not in (16, 24): - raise ValueError(f"Width must be 16 or 24, got {width}") engine = self._get_engine(width) rate = width - len(capacity_value) - # Pad the input vector with zeros to be an exact multiple of the rate size. + # Zero-pad to a multiple of the rate so absorption iterates exact chunks. num_extra = (rate - (len(input_vec) % rate)) % rate padded_input = input_vec + [Fp(value=0)] * num_extra - # Initialize the state: - # - capacity part (domain separator) at the beginning, - # - rate part (zero) follows. + # Layout: capacity slots first, then rate slots. cap_len = len(capacity_value) state = [Fp(value=0)] * width state[:cap_len] = capacity_value - # Absorb the input in rate-sized chunks via replacement. - for i in range(0, len(padded_input), rate): - chunk = padded_input[i : i + rate] - # Replace the rate part of the state with the chunk. - for j in range(rate): - state[cap_len + j] = chunk[j] - # Apply the cryptographic permutation to mix the state. + # Phase 2: absorb each chunk by overwriting the rate slots. + for chunk in batched(padded_input, rate): + for j, value in enumerate(chunk): + state[cap_len + j] = value state = engine.permute(state) - # Squeeze the output until enough elements have been generated. + # Phase 3: squeeze rate slots, permuting until enough output is available. output: list[Fp] = [] while len(output) < output_len: - # Extract the rate part of the state (after capacity) as output. output.extend(state[cap_len : cap_len + rate]) - # Permute the state. state = engine.permute(state) - # Truncate to the final output length and return. return output[:output_len] + def tweak_hash( + self, + config: XmssConfig, + parameter: Parameter, + tweak: TreeTweak | ChainTweak, + message_parts: list[HashDigestVector], + ) -> HashDigestVector: + """Apply the tweakable hash to one or more digests. + + Mode selection: + + - One digest input uses width-16 compression for hash chains. + - Two digest inputs use width-24 compression for Merkle interior nodes. + - More inputs use sponge mode for Merkle leaves. + + Args: + config: Active XMSS configuration. + parameter: Public parameter that personalizes the hash. + tweak: Position tweak for domain separation. + message_parts: Digests to hash together. + + Returns: + A digest of HASH_LEN_FE field elements. + """ + # Pack the tweak fields into one integer, then split it into base-P field elements. + # + # The low byte is a per-shape prefix. + # It stops a tree tweak and a chain tweak from packing to the same value. + # That keeps Merkle hashing domain-separated from chain hashing. + # + # Every other field sits in its own bit range above the prefix. + match tweak: + case TreeTweak(level=level, index=index): + acc = (level << 40) | (int(index) << 8) | TWEAK_PREFIX_TREE + case ChainTweak(epoch=epoch, chain_index=chain_index, step=step): + acc = (int(epoch) << 24) | (chain_index << 16) | (step << 8) | TWEAK_PREFIX_CHAIN + encoded_tweak = int_to_base_p(acc, config.TWEAK_LEN_FE) + + if len(message_parts) == 1: + # Hash chain step: width-16 compression of (digest || parameter || tweak). + input_vec = message_parts[0].elements + parameter.elements + encoded_tweak + result = self.compress(input_vec, 16, config.HASH_LEN_FE) + + elif len(message_parts) == 2: + # Merkle node: width-24 compression of (parameter || tweak || left || right). + input_vec = ( + parameter.elements + + encoded_tweak + + message_parts[0].elements + + message_parts[1].elements + ) + result = self.compress(input_vec, 24, config.HASH_LEN_FE) + + else: + # Merkle leaf: sponge mode over many concatenated digests. + flattened_message = [elem for part in message_parts for elem in part.elements] + input_vec = parameter.elements + encoded_tweak + flattened_message + + # The domain separator binds the sponge to this hashing task shape. + lengths = [ + config.PARAMETER_LEN, + config.TWEAK_LEN_FE, + config.DIMENSION, + config.HASH_LEN_FE, + ] + capacity_value = self.safe_domain_separator(lengths, config.CAPACITY) + result = self.sponge(input_vec, capacity_value, config.HASH_LEN_FE, 24) + + return HashDigestVector(data=result) + + def hash_chain( + self, + config: XmssConfig, + parameter: Parameter, + epoch: Uint64, + chain_index: int, + start_step: int, + num_steps: int, + start_digest: HashDigestVector, + ) -> HashDigestVector: + """Iterate the tweakable hash along a Winternitz chain. + + Each iteration uses a distinct chain tweak so every step is domain-separated. + + Args: + config: Active XMSS configuration. + parameter: Public parameter that personalizes the hash. + epoch: Slot identifier for the one-time signature. + chain_index: Index of the chain within the one-time signature. + start_step: Step number of the input digest. + num_steps: Number of additional hash applications. + start_digest: Digest at start_step. + + Returns: + Digest at start_step + num_steps. + """ + current_digest = start_digest + for i in range(num_steps): + # Steps are 1-indexed: step 1 is the first hash after the chain start. + tweak = ChainTweak(epoch=epoch, chain_index=chain_index, step=start_step + i + 1) + current_digest = self.tweak_hash(config, parameter, tweak, [current_digest]) + return current_digest + PROD_POSEIDON = PoseidonXmss(params16=PARAMS_16, params24=PARAMS_24) -"""An instance configured for production-level parameters.""" +"""Poseidon1 engine with production parameters.""" TEST_POSEIDON = PROD_POSEIDON -"""Test and production use the same Poseidon1 parameters; only XmssConfig differs.""" +"""Test environment reuses the production Poseidon1 parameters. +Only the surrounding configuration differs between modes.""" diff --git a/src/lean_spec/subspecs/xmss/prf.py b/src/lean_spec/subspecs/xmss/prf.py index b90bed660..fc086ebdf 100644 --- a/src/lean_spec/subspecs/xmss/prf.py +++ b/src/lean_spec/subspecs/xmss/prf.py @@ -1,209 +1,120 @@ -""" -Defines the pseudorandom function (PRF) used in the signature scheme. - -PRF based on the SHAKE128 extendable-output function (XOF). +"""SHAKE128-based pseudorandom function for deterministic key derivation. -The PRF is used to derive the secret starting points of the hash chains -for each epoch from a single master secret key. +Derives hash-chain starts and signing randomness from one master key. +Every call is domain-separated so the same key never collides across contexts. """ -from __future__ import annotations - import hashlib import os +from itertools import batched from typing import Final -from lean_spec.subspecs.koalabear import Fp -from lean_spec.types import Bytes32, StrictBaseModel, Uint64 +from lean_spec.types import Bytes32, Uint64 -from .constants import ( - PRF_KEY_LENGTH, - PROD_CONFIG, - TEST_CONFIG, - XmssConfig, -) +from ..koalabear import Fp +from .constants import PRF_KEY_LENGTH, XmssConfig from .types import HashDigestVector, PRFKey, Randomness PRF_DOMAIN_SEP: Final[bytes] = b"\xae\xae\x22\xff\x00\x01\xfa\xff\x21\xaf\x12\x00\x01\x11\xff\x00" -""" -A 16-byte domain separator to ensure PRF outputs are unique to this context. - -This prevents any potential conflicts if the same underlying hash function -(SHAKE128) were used for other purposes in the system. -""" +"""Fixed 16-byte domain separator used by every PRF call. +Prevents cross-context collisions if SHAKE128 is reused elsewhere in the system.""" PRF_DOMAIN_SEP_DOMAIN_ELEMENT: Final[bytes] = b"\x00" -""" -A 1-byte domain separator for deriving domain elements (used in `apply`). - -This distinguishes the PRF calls for generating hash chain starting points -from the PRF calls for generating randomness during signing. -""" +"""Subdomain tag for hash-chain start derivation.""" PRF_DOMAIN_SEP_RANDOMNESS: Final[bytes] = b"\x01" -""" -A 1-byte domain separator for deriving randomness (used in `get_randomness`). - -This distinguishes the PRF calls for generating signing randomness from the -PRF calls for generating domain elements, preventing any potential collisions -between the two use cases. -""" +"""Subdomain tag for signing-randomness derivation.""" PRF_BYTES_PER_FE: Final[int] = 16 -""" -The number of bytes of SHAKE128 output used to generate one field element. +"""SHAKE128 bytes consumed per output field element. +128 bits reduced modulo a 31-bit prime gives a statistical margin against bias.""" -We use 16 bytes (128 bits) of pseudorandom output, which is then reduced -modulo the 31-bit field prime `P`. This provides a significant statistical -safety margin to ensure the resulting field element is close to uniformly -random. -""" +def prf_key_gen() -> PRFKey: + """Generate a fresh master PRF key from the operating system entropy pool.""" + return PRFKey(os.urandom(PRF_KEY_LENGTH)) -def _bytes_to_field_elements(data: bytes, count: int) -> list[Fp]: - """ - Convert PRF output bytes into a list of field elements. - Each field element is derived from `PRF_BYTES_PER_FE` bytes, - interpreted as a big-endian integer and reduced modulo the field prime. +def prf_apply( + config: XmssConfig, key: PRFKey, epoch: Uint64, chain_index: Uint64 +) -> HashDigestVector: + """Derive the secret start of one Winternitz hash chain. - The extra bits provide statistical uniformity. + Args: + config: Active XMSS configuration. + key: Master PRF key. + epoch: Slot identifier for this one-time signature instance. + chain_index: Position of the chain within the one-time signature. + + Returns: + A hash digest used as the chain start. + """ + # Layout: + # + # domain_sep || 0x00 || key || epoch (4 bytes) || chain_index (8 bytes) + # + # The 0x00 byte separates chain-start derivation from randomness derivation. + input_data = ( + PRF_DOMAIN_SEP + + PRF_DOMAIN_SEP_DOMAIN_ELEMENT + + key + + epoch.to_bytes(4, "big") + + chain_index.to_bytes(8, "big") + ) + + # Pull enough SHAKE128 bytes to fill HASH_LEN_FE field elements. + num_bytes_to_read = PRF_BYTES_PER_FE * config.HASH_LEN_FE + prf_output_bytes = hashlib.shake_128(input_data).digest(num_bytes_to_read) + return HashDigestVector( + data=[ + Fp(value=int.from_bytes(bytes(chunk), "big")) + for chunk in batched(prf_output_bytes, PRF_BYTES_PER_FE) + ] + ) + + +def prf_get_randomness( + config: XmssConfig, + key: PRFKey, + epoch: Uint64, + message: Bytes32, + counter: Uint64, +) -> Randomness: + """Derive deterministic randomness for a signing attempt. + + Same construction as the chain-start derivation, with a different subdomain + tag. + Including the message and counter makes signing reproducible without + breaking security: signing twice with the same key, epoch, and message + always produces the same randomness. Args: - data: Raw bytes from SHAKE128 output. Must be exactly `count * PRF_BYTES_PER_FE` bytes. - count: Number of field elements to extract. + config: Active XMSS configuration. + key: Master PRF key. + epoch: Slot identifier for this signature. + message: Full message being signed. + counter: Attempt number, incremented when a previous attempt aborted. Returns: - List of `count` field elements. + Randomness used to encode the message into a valid codeword. """ - return [ - Fp(value=int.from_bytes(data[i : i + PRF_BYTES_PER_FE], "big")) - for i in range(0, count * PRF_BYTES_PER_FE, PRF_BYTES_PER_FE) - ] - - -class Prf(StrictBaseModel): - """An instance of the SHAKE128-based PRF for a given config.""" - - config: XmssConfig - """Configuration parameters for the PRF.""" - - def key_gen(self) -> PRFKey: - """ - Generates a cryptographically secure random key for the PRF. - - This function sources randomness from the operating system's - entropy pool. - - Returns: - A new, randomly generated PRF key of `PRF_KEY_LENGTH` bytes. - """ - return PRFKey(os.urandom(PRF_KEY_LENGTH)) - - def apply(self, key: PRFKey, epoch: Uint64, chain_index: Uint64) -> HashDigestVector: - """ - Applies the PRF to derive the secret starting value for a single hash chain. - - ### PRF Construction - - The function constructs a unique input for the underlying SHAKE128 function - by concatenating several components: - `SHAKE128(DOMAIN_SEP || 0x00 || key || epoch || chain_index)` - - The 0x00 byte distinguishes this use case (deriving domain elements) from - randomness generation (which uses 0x01). The arbitrary-length output of - SHAKE128 is then processed to produce a list of field elements, which - serves as the secret starting digest for one chain. - - Args: - key: The secret master PRF key. - epoch: The epoch number, identifying the one-time signature instance. - chain_index: The index of the hash chain within that epoch's OTS. - - Returns: - A hash digest representing the secret start of a single hash chain. - """ - # Retrieve the scheme's configuration parameters. - config = self.config - - # Construct the unique input for the PRF by concatenating its components: - # - # - Domain Separation: Uniquely tag the PRF for this specific use case. - # - Domain Element Tag: 0x00 byte to distinguish from randomness generation. - # - Key Input: The master secret key. - # - Epoch: A 4-byte integer ensuring every epoch derives a different set of secrets. - # - Chain Index: An 8-byte integer ensuring each parallel hash chain gets a unique secret. - input_data = ( - PRF_DOMAIN_SEP - + PRF_DOMAIN_SEP_DOMAIN_ELEMENT - + key - + epoch.to_bytes(4, "big") - + chain_index.to_bytes(8, "big") - ) - - # Determine the total number of bytes to extract from the SHAKE output. - # - # We need enough bytes to produce `HASH_LEN_FE` field elements. - num_bytes_to_read = PRF_BYTES_PER_FE * config.HASH_LEN_FE - prf_output_bytes = hashlib.shake_128(input_data).digest(num_bytes_to_read) - - # Convert the raw byte output into a list of field elements. - return HashDigestVector(data=_bytes_to_field_elements(prf_output_bytes, config.HASH_LEN_FE)) - - def get_randomness( - self, key: PRFKey, epoch: Uint64, message: Bytes32, counter: Uint64 - ) -> Randomness: - """ - Derives pseudorandom field elements for use in deterministic signing. - - This method is used to generate deterministic randomness for the Information - Encoding step during signing. By deriving randomness from the PRF key, epoch, - message, and attempt counter, we ensure that signing is deterministic: calling - `sign` twice with the same (sk, epoch, message) triple produces the same signature. - - This provides additional hardening against implementation errors where sign might - be called multiple times with the same epoch. However, calling sign with the same - epoch but *different* messages still compromises security. - - ### Construction - - Similar to `apply`, but includes the message and a counter in the input: - `SHAKE128(DOMAIN_SEP || 0x01 || key || epoch || message || counter)` - - The 0x01 byte distinguishes this use case (generating randomness) from - domain element derivation (which uses 0x00). - - Args: - key: The secret master PRF key. - epoch: The epoch number for this signature. - message: The message being signed (MESSAGE_LENGTH bytes). - counter: The attempt number (used when retrying encoding). - - Returns: - Randomness for encoding (i.e., `rho`). - """ - config = self.config - - # Construct input: DOMAIN_SEP || 0x01 || key || epoch || message || counter - input_data = ( - PRF_DOMAIN_SEP - + PRF_DOMAIN_SEP_RANDOMNESS - + key - + epoch.to_bytes(4, "big") - + message - + counter.to_bytes(8, "big") - ) - - # Extract enough bytes for RAND_LEN_FE field elements - num_bytes_to_read = PRF_BYTES_PER_FE * config.RAND_LEN_FE - prf_output_bytes = hashlib.shake_128(input_data).digest(num_bytes_to_read) - - # Convert to field elements and wrap in Randomness - return Randomness(data=_bytes_to_field_elements(prf_output_bytes, config.RAND_LEN_FE)) - - -PROD_PRF = Prf(config=PROD_CONFIG) -"""An instance configured for production-level parameters.""" - -TEST_PRF = Prf(config=TEST_CONFIG) -"""A lightweight instance for test environments.""" + # Layout: + # + # domain_sep || 0x01 || key || epoch || message || counter + input_data = ( + PRF_DOMAIN_SEP + + PRF_DOMAIN_SEP_RANDOMNESS + + key + + epoch.to_bytes(4, "big") + + message + + counter.to_bytes(8, "big") + ) + + num_bytes_to_read = PRF_BYTES_PER_FE * config.RAND_LEN_FE + prf_output_bytes = hashlib.shake_128(input_data).digest(num_bytes_to_read) + return Randomness( + data=[ + Fp(value=int.from_bytes(bytes(chunk), "big")) + for chunk in batched(prf_output_bytes, PRF_BYTES_PER_FE) + ] + ) diff --git a/src/lean_spec/subspecs/xmss/rand.py b/src/lean_spec/subspecs/xmss/rand.py deleted file mode 100644 index fa84e80ce..000000000 --- a/src/lean_spec/subspecs/xmss/rand.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Random data generator for the XMSS signature scheme.""" - -import secrets - -from lean_spec.types import StrictBaseModel - -from ..koalabear import Fp, P -from .constants import PROD_CONFIG, TEST_CONFIG, XmssConfig -from .types import HashDigestVector, Parameter - - -class Rand(StrictBaseModel): - """An instance of the random data generator for a given config.""" - - config: XmssConfig - """Configuration parameters for the random generator.""" - - def field_elements(self, length: int) -> list[Fp]: - """Generates a random list of field elements.""" - # For each element, generate a secure random integer in the range [0, P-1]. - return [Fp(value=secrets.randbelow(P)) for _ in range(length)] - - def parameter(self) -> Parameter: - """Generates a random public parameter.""" - return Parameter(data=self.field_elements(self.config.PARAMETER_LEN)) - - def domain(self) -> HashDigestVector: - """Generates a random hash digest.""" - return HashDigestVector(data=self.field_elements(self.config.HASH_LEN_FE)) - - -PROD_RAND = Rand(config=PROD_CONFIG) -"""An instance configured for production-level parameters.""" - -TEST_RAND = Rand(config=TEST_CONFIG) -"""A lightweight instance for test environments.""" diff --git a/src/lean_spec/subspecs/xmss/subtree.py b/src/lean_spec/subspecs/xmss/subtree.py deleted file mode 100644 index baf66bf04..000000000 --- a/src/lean_spec/subspecs/xmss/subtree.py +++ /dev/null @@ -1,615 +0,0 @@ -""" -Subtree construction and manipulation for top-bottom Merkle tree traversal. - -This module contains the `HashSubTree` type and its associated construction methods, -implementing the memory-efficient top-bottom tree traversal approach. -""" - -from typing import Self - -from lean_spec.types import Uint64 -from lean_spec.types.container import Container - -from .constants import XmssConfig -from .prf import Prf -from .rand import Rand -from .tweak_hash import TreeTweak, TweakHasher -from .types import ( - HashDigestList, - HashDigestVector, - HashTreeLayer, - HashTreeLayers, - HashTreeOpening, - Parameter, - PRFKey, -) -from .utils import get_padded_layer - - -class HashSubTree(Container): - """ - Represents a subtree of a sparse Merkle tree. - - This is the building block for the top-bottom tree traversal approach, - which splits a large Merkle tree into: - - **One top tree**: Contains the root and the top `LOG_LIFETIME/2` layers - - **Multiple bottom trees**: Each contains `sqrt(LIFETIME)` leaves - - A subtree can represent either a complete tree (from layer 0) or a partial tree - starting from a higher layer (like a top tree starting from layer `LOG_LIFETIME/2`). - - The layers are stored from `lowest_layer` up to the root, with padding applied - to ensure even alignment for efficient parent computation. - - Memory Efficiency - ----------------- - For a key with lifetime 2^32: - - Traditional approach: O(2^32) = requires hundreds of GiB - - Top-bottom approach: O(sqrt(2^32)) = O(2^16) ≈ much less memory - - The secret key maintains: - - The full top tree (sparse, only active roots) - - Two consecutive bottom trees (sliding window) - - SSZ Container with fields: - - depth: uint64 - - lowest_layer: uint64 - - layers: List[HashTreeLayer, LAYERS_LIMIT] - - Serialization is handled automatically by SSZ. - """ - - depth: Uint64 - """ - The total depth of the full tree (e.g., 32 for a 2^32 leaf space). - - This represents the depth of the complete Merkle tree, not just this subtree. - A subtree starting from layer `k` will have `depth - k` layers stored. - """ - - lowest_layer: Uint64 - """ - The lowest layer included in this subtree. - - - For bottom trees: `lowest_layer = 0` (includes leaves) - - For top trees: `lowest_layer = LOG_LIFETIME/2` (starts from middle) - - Example: For LOG_LIFETIME=32, top tree has lowest_layer=16, containing - layers 16 through 32 (the root). - """ - - layers: HashTreeLayers - """ - The layers of this subtree, from `lowest_layer` to the root. - - SSZ notation: `List[HashTreeLayer, LAYERS_LIMIT]` - - - `layers[0]` corresponds to layer `lowest_layer` in the full tree - - `layers[-1]` corresponds to the highest layer in this subtree - - For bottom trees: the last layer contains a single root - - For top trees: the last layer contains the global root - - Each layer maintains the padding invariant: start index is even, - end index is odd (except for single-node layers). - """ - - @classmethod - def new( - cls, - hasher: TweakHasher, - rand: Rand, - lowest_layer: Uint64, - depth: Uint64, - start_index: Uint64, - parameter: Parameter, - lowest_layer_nodes: list[HashDigestVector], - ) -> Self: - """ - Builds a new sparse Merkle subtree starting from a specified layer. - - This is the general constructor for subtrees and is used internally by - `new_top_tree()` and `new_bottom_tree()`. A subtree can start from any - layer, not just layer 0 (leaves). - - ### Construction Algorithm - - 1. **Initialization**: Start with the provided nodes at `lowest_layer`, - apply padding to ensure even alignment. - - 2. **Bottom-Up Iteration**: Build the tree layer by layer, from the - lowest layer towards the root. - - 3. **Parent Generation**: At each level, group nodes into pairs - (left, right) and hash them to create parent nodes. - - 4. **Padding**: Apply padding to each new layer to maintain the - even-alignment invariant. - - 5. **Termination**: Continue until reaching the root or the desired - highest layer. - - Args: - hasher: The tweakable hash instance for computing parent nodes. - rand: Random generator for padding values. - lowest_layer: The starting layer for this subtree (0 for full trees). - depth: The total depth of the full tree (e.g., 32 for 2^32 leaves). - start_index: The absolute index at `lowest_layer` where this subtree begins. - parameter: The public parameter `P` for the hash function. - lowest_layer_nodes: The hash nodes at `lowest_layer` to build from. - - Returns: - A `HashSubTree` containing all computed layers from `lowest_layer` to root. - """ - # Validate: nodes must fit in available positions at this layer. - max_positions = 1 << int(depth - lowest_layer) - if int(start_index) + len(lowest_layer_nodes) > max_positions: - raise ValueError( - f"Overflow at layer {lowest_layer}: " - f"start={start_index}, count={len(lowest_layer_nodes)}, max={max_positions}" - ) - - # Initialize with padded input layer. - layers: list[HashTreeLayer] = [] - current = get_padded_layer(rand, lowest_layer_nodes, start_index) - layers.append(current) - - # Build upward: hash pairs of children to create parents. - for level in range(lowest_layer, depth): - parent_start = current.start_index // Uint64(2) - - # Hash each pair of siblings into their parent using zip for cleaner indexing. - node_pairs = zip(current.nodes[::2], current.nodes[1::2], strict=True) - parents = [ - hasher.apply( - parameter, - TreeTweak(level=level + 1, index=parent_start + Uint64(i)), - [left, right], - ) - for i, (left, right) in enumerate(node_pairs) - ] - - # Pad and store the new layer. - current = get_padded_layer(rand, parents, parent_start) - layers.append(current) - - return cls( - depth=depth, - lowest_layer=lowest_layer, - layers=HashTreeLayers(data=layers), - ) - - @classmethod - def new_top_tree( - cls, - hasher: TweakHasher, - rand: Rand, - depth: int, - start_bottom_tree_index: Uint64, - parameter: Parameter, - bottom_tree_roots: list[HashDigestVector], - ) -> Self: - """ - Constructs a top tree from the roots of bottom trees. - - For top-bottom tree traversal, the full Merkle tree is split into: - - **Top tree**: Contains root and top `LOG_LIFETIME/2` layers - - **Bottom trees**: Each contains `sqrt(LIFETIME)` leaves - - The top tree's lowest layer contains the roots of all bottom trees, - and it is built upward from there to the global root. - - ### Algorithm - - 1. **Determine lowest layer**: For a tree of depth `d`, the top tree - starts at layer `d/2` (the middle of the tree). - - 2. **Build upward**: Use `new()` to build from the bottom tree - roots up to the global root. - - Args: - hasher: The tweakable hash instance for computing parent nodes. - rand: Random generator for padding values. - depth: The total depth of the full tree (must be even for top-bottom split). - start_bottom_tree_index: The index of the first bottom tree in the range. - parameter: The public parameter `P` for the hash function. - bottom_tree_roots: The list of roots from all bottom trees in order. - - Returns: - A `HashSubTree` representing the top tree with `lowest_layer = depth/2`. - - Raises: - ValueError: If depth is odd (top-bottom split requires even depth). - """ - if depth % 2 != 0: - raise ValueError(f"Depth must be even for top-bottom split, got {depth}.") - - # Build from middle layer using bottom tree roots as leaves. - return cls.new( - hasher=hasher, - rand=rand, - lowest_layer=Uint64(depth // 2), - depth=Uint64(depth), - start_index=start_bottom_tree_index, - parameter=parameter, - lowest_layer_nodes=bottom_tree_roots, - ) - - @classmethod - def new_bottom_tree( - cls, - hasher: TweakHasher, - rand: Rand, - depth: int, - bottom_tree_index: Uint64, - parameter: Parameter, - leaves: list[HashDigestVector], - ) -> Self: - """ - Constructs a single bottom tree from leaf hashes. - - A bottom tree covers `sqrt(LIFETIME)` consecutive epochs. For a tree with - `LOG_LIFETIME = 32`, each bottom tree covers 2^16 = 65536 epochs. - - Bottom trees are numbered 0, 1, 2, ... where tree `i` covers epochs - `[i * sqrt(LIFETIME), (i+1) * sqrt(LIFETIME))`. - - ### Algorithm - - 1. **Build full tree**: First, build a complete subtree from layer 0 - using the provided leaves. - - 2. **Truncate incompatible top layers**: The full tree computation adds - padding nodes in upper layers that would be incompatible with other - bottom trees. We remove these layers. - - 3. **Replace with standalone root**: Extract the root at layer `depth/2` - and make it the highest layer of this bottom tree. - - Args: - hasher: The tweakable hash instance for computing parent nodes. - rand: Random generator for padding values. - depth: The total depth of the full tree (must be even). - bottom_tree_index: The index of this bottom tree (0, 1, 2, ...). - parameter: The public parameter `P` for the hash function. - leaves: The pre-hashed leaf nodes (one-time public keys). - - Returns: - A `HashSubTree` with layers 0 through `depth/2`, where the highest - layer contains only the bottom tree's root. - - Raises: - ValueError: If depth is odd or leaves count doesn't match `sqrt(LIFETIME)`. - """ - if depth % 2 != 0: - raise ValueError(f"Depth must be even for top-bottom split, got {depth}.") - - # Each bottom tree has exactly sqrt(LIFETIME) leaves. - leaves_per_tree = 1 << (depth // 2) - if len(leaves) != leaves_per_tree: - raise ValueError( - f"Expected {leaves_per_tree} leaves for depth={depth}, got {len(leaves)}." - ) - - # Build full tree from leaves. - full_tree = cls.new( - hasher=hasher, - rand=rand, - lowest_layer=Uint64(0), - depth=Uint64(depth), - start_index=bottom_tree_index * Uint64(leaves_per_tree), - parameter=parameter, - lowest_layer_nodes=leaves, - ) - - # Extract root from middle layer. - middle = full_tree.layers[depth // 2] - root_idx = int(bottom_tree_index - middle.start_index) - root_layer = HashTreeLayer( - start_index=bottom_tree_index, - nodes=HashDigestList(data=[middle.nodes[root_idx]]), - ) - - # Keep bottom half + single root node. - truncated = list(full_tree.layers[: depth // 2]) - return cls( - depth=Uint64(depth), - lowest_layer=Uint64(0), - layers=HashTreeLayers(data=truncated + [root_layer]), - ) - - @classmethod - def from_prf_key( - cls, - prf: Prf, - hasher: TweakHasher, - rand: Rand, - config: XmssConfig, - prf_key: PRFKey, - bottom_tree_index: Uint64, - parameter: Parameter, - ) -> Self: - """ - Generates a single bottom tree on-demand from the PRF key. - - This is a key component of the top-bottom tree approach: instead of storing all - one-time secret keys, we regenerate them on-demand using the PRF. This enables - O(sqrt(LIFETIME)) memory usage. - - ### Algorithm - - 1. **Determine epoch range**: Bottom tree `i` covers epochs - `[i * sqrt(LIFETIME), (i+1) * sqrt(LIFETIME))` - - 2. **Generate leaves**: For each epoch in parallel: - - For each chain (0 to DIMENSION-1): - - Derive secret start: `PRF(prf_key, epoch, chain_index)` - - Compute public end: hash chain for `BASE - 1` steps - - Hash all chain ends to get the leaf - - 3. **Build bottom tree**: Construct the bottom tree from the leaves - - Args: - prf: The PRF instance for key derivation. - hasher: The tweakable hash instance. - rand: Random generator for padding values. - config: The XMSS configuration. - prf_key: The master PRF secret key. - bottom_tree_index: The index of the bottom tree to generate (0, 1, 2, ...). - parameter: The public parameter `P` for the hash function. - - Returns: - A `HashSubTree` representing the requested bottom tree. - """ - # Calculate the number of leaves per bottom tree: sqrt(LIFETIME). - leaves_per_bottom_tree = 1 << (config.LOG_LIFETIME // 2) - - # Determine the epoch range for this bottom tree. - start_epoch = bottom_tree_index * Uint64(leaves_per_bottom_tree) - end_epoch = start_epoch + Uint64(leaves_per_bottom_tree) - - # Generate leaf hashes for all epochs in this bottom tree. - leaf_hashes: list[HashDigestVector] = [] - - for epoch in range(start_epoch, end_epoch): - # For each epoch, compute the one-time public key (chain endpoints). - chain_ends: list[HashDigestVector] = [] - - for chain_index in range(config.DIMENSION): - # Derive the secret start of the chain from the PRF key. - start_digest = prf.apply(prf_key, Uint64(epoch), Uint64(chain_index)) - - # Compute the public end by hashing BASE - 1 times. - end_digest = hasher.hash_chain( - parameter=parameter, - epoch=Uint64(epoch), - chain_index=chain_index, - start_step=0, - num_steps=config.BASE - 1, - start_digest=start_digest, - ) - chain_ends.append(end_digest) - - # Hash the chain ends to get the leaf for this epoch. - leaf_tweak = TreeTweak(level=0, index=Uint64(epoch)) - leaf_hash = hasher.apply(parameter, leaf_tweak, chain_ends) - leaf_hashes.append(leaf_hash) - - # Build the bottom tree from the leaf hashes. - return cls.new_bottom_tree( - hasher=hasher, - rand=rand, - depth=config.LOG_LIFETIME, - bottom_tree_index=bottom_tree_index, - parameter=parameter, - leaves=leaf_hashes, - ) - - def root(self) -> HashDigestVector: - """ - Extracts the root digest from this subtree. - - For top-bottom tree traversal, a subtree's root is the single node - in its highest layer. - - Returns: - The root hash digest of the subtree. - - Raises: - ValueError: If the subtree has no layers or the highest layer is empty. - """ - if not self.layers: - raise ValueError("Empty subtree has no root.") - if not self.layers[-1].nodes: - raise ValueError("Top layer is empty.") - return self.layers[-1].nodes[0] - - def path(self, position: Uint64) -> HashTreeOpening: - """ - Computes the authentication path for a leaf within this subtree. - - This is similar to full tree path computation but works with subtrees that may - not start from layer 0. The path is computed from the specified position up to - (but not including) the subtree's root. - - For a subtree covering layers L through H (where H is the highest/root layer), - this generates H - L siblings: one for each layer from L to H-1. - - Args: - position: The absolute index of the leaf in the full tree coordinate system. - - Returns: - A `HashTreeOpening` containing the sibling hashes for the path. - - Raises: - ValueError: If the subtree is empty or the position is out of bounds. - """ - if not self.layers: - raise ValueError("Empty subtree.") - - # Check bounds. - first = self.layers[0] - if not (first.start_index <= position < first.start_index + Uint64(len(first.nodes))): - raise ValueError(f"Position {position} out of bounds.") - - # Collect sibling at each layer (except root). - siblings: list[HashDigestVector] = [] - pos = position - - # Iterate over all layers except the last (root). - for layer in self.layers[:-1]: - # Sibling index: flip last bit of position, adjust for layer offset. - sibling_idx = int((pos ^ Uint64(1)) - layer.start_index) - if not (0 <= sibling_idx < len(layer.nodes)): - raise ValueError(f"Sibling index {sibling_idx} out of bounds.") - - siblings.append(layer.nodes[sibling_idx]) - pos = pos // Uint64(2) # Move to parent position. - - return HashTreeOpening(siblings=HashDigestList(data=siblings)) - - -def combined_path( - top_tree: HashSubTree, - bottom_tree: HashSubTree, - position: Uint64, -) -> HashTreeOpening: - """ - Generates a combined authentication path spanning top and bottom trees. - - For top-bottom tree traversal, a signature's authentication path must prove - that a leaf is part of the global Merkle root. This requires two proofs: - - 1. **Bottom tree path**: Proves the leaf is part of its bottom tree's root - 2. **Top tree path**: Proves the bottom tree's root is part of the global root - - This function combines both paths into a single `HashTreeOpening` that can - be used for verification. - - ### Algorithm - - 1. **Determine which bottom tree**: Calculate which bottom tree contains - the specified position. - - 2. **Get bottom tree path**: Extract the authentication path from the leaf - up to the bottom tree's root (depth/2 siblings). - - 3. **Get top tree path**: Extract the authentication path from the bottom - tree's root up to the global root (depth/2 siblings). - - 4. **Concatenate**: Combine both paths into a single path with `depth` siblings. - - Args: - top_tree: The top tree containing the global root. - bottom_tree: The bottom tree containing the specified position. - position: The absolute epoch/leaf index to generate a path for. - - Returns: - A `HashTreeOpening` with `depth` siblings that authenticates the leaf - against the global root. - - Raises: - ValueError: If trees have mismatched depths, odd depth, or position is - out of bounds for the bottom tree. - """ - # Validate matching depths. - if top_tree.depth != bottom_tree.depth: - raise ValueError(f"Depth mismatch: top={top_tree.depth}, bottom={bottom_tree.depth}.") - - depth = int(top_tree.depth) - if depth % 2 != 0: - raise ValueError(f"Depth must be even, got {depth}.") - - # Validate bottom tree matches position. - leaves_per_tree = Uint64(1 << (depth // 2)) - expected_start = (position // leaves_per_tree) * leaves_per_tree - if bottom_tree.layers[0].start_index != expected_start: - raise ValueError( - f"Wrong bottom tree: position {position} needs start {expected_start}, " - f"got {bottom_tree.layers[0].start_index}." - ) - - # Concatenate: bottom path + top path. - bottom_path = bottom_tree.path(position) - top_path = top_tree.path(position // leaves_per_tree) - combined = tuple(bottom_path.siblings.data) + tuple(top_path.siblings.data) - - return HashTreeOpening(siblings=HashDigestList(data=combined)) - - -def verify_path( - hasher: TweakHasher, - parameter: Parameter, - root: HashDigestVector, - position: Uint64, - leaf_parts: list[HashDigestVector], - opening: HashTreeOpening, -) -> bool: - """ - Verifies a Merkle authentication path against a known, trusted root. - - This function is the final check in signature verification. It proves that the - one-time public key used for the signature (represented by `leaf_parts`) is a - legitimate member of the set committed to by the Merkle `root`. - - ### Verification Algorithm - - 1. **Leaf Computation**: The process begins at the bottom. The verifier first - hashes the `leaf_parts` to compute the actual leaf digest. This becomes the - starting `current_node` for the climb up the tree. - - 2. **Bottom-Up Reconstruction**: The verifier iterates through the `opening.siblings` - path. At each `level`, it takes the `current_node` and the `sibling_node` - from the path. - - 3. **Parent Calculation**: It determines if the `current_node` is a left or - right child based on its `position`. The two nodes are placed in the - correct `(left, right)` order and hashed (with the correct `TreeTweak`) - to compute the parent. This parent becomes the `current_node` for the - next level. - - 4. **Final Comparison**: After all siblings are used, the final `current_node` - is the candidate root. The path is valid if and only if it matches the trusted `root`. - - Args: - hasher: The tweakable hash instance for computing parent nodes. - parameter: The public parameter `P` for the hash function. - root: The known, trusted Merkle root from the public key. - position: The absolute index of the leaf being verified. - leaf_parts: The list of digests that constitute the original leaf. - opening: The `HashTreeOpening` object containing the sibling path. - - Returns: - `True` if the path is valid and reconstructs the root, `False` otherwise. - Returns `False` for invalid inputs (depth > 32 or position out of bounds). - """ - # Validate depth and position bounds. - # - # These checks guard against malformed attacker-controlled input. - # Return False instead of raising to avoid panic on invalid signatures. - depth = len(opening.siblings) - if depth > 32: - return False - if int(position) >= (1 << depth): - return False - - # Start: hash leaf parts to get leaf node. - current = hasher.apply( - parameter, - TreeTweak(level=0, index=Uint64(position)), - leaf_parts, - ) - pos = int(position) - - # Walk up: hash current with each sibling. - for level, sibling in enumerate(opening.siblings): - # Left child has even position, right child has odd. - left, right = (current, sibling) if pos % 2 == 0 else (sibling, current) - pos //= 2 # Parent position. - current = hasher.apply( - parameter, - TreeTweak(level=level + 1, index=Uint64(pos)), - [left, right], - ) - - # Valid if we reconstructed the expected root. - return current == root diff --git a/src/lean_spec/subspecs/xmss/target_sum.py b/src/lean_spec/subspecs/xmss/target_sum.py deleted file mode 100644 index 985b29c38..000000000 --- a/src/lean_spec/subspecs/xmss/target_sum.py +++ /dev/null @@ -1,89 +0,0 @@ -""" -Implements the Top Level Target Sum Winternitz incomparable encoding scheme. - -This module provides the logic for converting a message hash into a valid -codeword for the one-time signature part of the scheme. It acts as a filter on -top of the message hash output. -""" - -from __future__ import annotations - -from lean_spec.types import Bytes32, StrictBaseModel, Uint64 - -from .constants import PROD_CONFIG, TEST_CONFIG, XmssConfig -from .message_hash import ( - PROD_MESSAGE_HASHER, - TEST_MESSAGE_HASHER, - MessageHasher, -) -from .types import Parameter, Randomness - - -class TargetSumEncoder(StrictBaseModel): - """ - An instance of the Target Sum encoder for a given configuration. - - This class encapsulates the logic for validating a message hash against the - scheme's target sum constraint. - """ - - config: XmssConfig - """Configuration parameters for the encoder.""" - - message_hasher: MessageHasher - """Message hasher for encoding.""" - - def encode( - self, parameter: Parameter, message: Bytes32, rho: Randomness, epoch: Uint64 - ) -> list[int] | None: - """ - Encodes a message into a codeword if it meets the target sum criteria. - - ### Encoding Algorithm - - 1. **Hashing to a Vertex**: The function first hashes the inputs (`message`, - `rho`, etc.) to produce a candidate codeword. This can be viewed as - mapping the inputs to a vertex in a high-dimensional hypercube, where - the vertex's coordinates are the digits of the codeword. - - 2. **Target Sum Validation**: It then checks if the sum of the candidate's digits - matches the scheme's predefined `TARGET_SUM`. This is equivalent to - verifying that the vertex lies on the correct hypercube layer. This - constraint is critical for the scheme's security and ensures a - predictable number of hash operations during signature verification. - - Args: - parameter: The public parameter `P`, used for domain separation. - message: The message to encode. - rho: The randomness used for this specific encoding attempt. - epoch: The current epoch, used as part of the hash input. - - Returns: - - The codeword (a list of integers) if the sum matches the target. - - Otherwise, it returns `None` to signal that this attempt failed and - a new `rho` must be tried. - """ - # Hash the inputs to map them to a potential codeword (a vertex in the hypercube). - codeword_candidate = self.message_hasher.apply(parameter, epoch, rho, message) - - # The aborting decode may reject if a field element equals P - 1. - if codeword_candidate is None: - return None - - # A codeword is valid only if it lies on the predefined hypercube layer. - # - # This is verified by checking if the sum of its coordinates equals TARGET_SUM. - if sum(codeword_candidate) == self.config.TARGET_SUM: - # If the sum is correct, this is a valid codeword for the one-time signature. - return codeword_candidate - - # If the sum does not match, this `rho` is invalid for this message. - # The caller will need to try again with new randomness. - return None - - -PROD_TARGET_SUM_ENCODER = TargetSumEncoder(config=PROD_CONFIG, message_hasher=PROD_MESSAGE_HASHER) -"""An instance configured for production-level parameters.""" - -TEST_TARGET_SUM_ENCODER = TargetSumEncoder(config=TEST_CONFIG, message_hasher=TEST_MESSAGE_HASHER) -"""A lightweight instance for test environments.""" diff --git a/src/lean_spec/subspecs/xmss/tweak_hash.py b/src/lean_spec/subspecs/xmss/tweak_hash.py deleted file mode 100644 index d83c2ea70..000000000 --- a/src/lean_spec/subspecs/xmss/tweak_hash.py +++ /dev/null @@ -1,253 +0,0 @@ -""" -Defines the Tweakable Hash function using Poseidon1. - -### The Problem: Hash Function Overload - -In a complex cryptographic scheme like XMSS, a single hash function (like Poseidon1) -is used for many different purposes: -1. Hashing iteratively to form **hash chains**. -2. Hashing pairs of nodes to build the **Merkle tree**. -3. Hashing the one-time public key to form a **Merkle leaf**. - -If we simply called `hash(data)` for all these cases, we could run into a critical -security issue: a "collision" between different contexts. For example, the output of a -hash in a chain might accidentally be identical to the hash of two nodes in the tree. -This could allow an attacker to create forgeries. - -### The Solution: Tweakable Hashing - -A **tweakable hash function** solves this by treating each hash computation as having a -unique "address" or "tweak". - -The function's signature becomes `hash(tweak, data)`. By ensuring every tweak is -unique across the entire scheme, we guarantee that every hash computation is -**domain-separated**, eliminating the risk of cross-context collisions. -""" - -from typing import NamedTuple - -from lean_spec.types import StrictBaseModel, Uint64 - -from ..koalabear import Fp -from .constants import ( - PROD_CONFIG, - TEST_CONFIG, - TWEAK_PREFIX_CHAIN, - TWEAK_PREFIX_TREE, - XmssConfig, -) -from .poseidon import ( - PROD_POSEIDON, - TEST_POSEIDON, - PoseidonXmss, -) -from .types import HashDigestVector, Parameter -from .utils import int_to_base_p - - -class TreeTweak(NamedTuple): - """ - A tweak used for hashing nodes within the Merkle tree. - - This structure ensures that every hash computed during the construction of the - Merkle tree has a unique context. - """ - - level: int - """The level (height) in the Merkle tree, where 0 is the leaf level.""" - - index: Uint64 - """The node's index (from the left) within that level.""" - - -class ChainTweak(NamedTuple): - """ - A tweak used for hashing elements within a WOTS+ hash chain. - - This structure ensures every iterative hash within every one-time signature - chain is distinct across all epochs. - """ - - epoch: Uint64 - """The signature epoch.""" - - chain_index: int - """The index of the hash chain (from 0 to DIMENSION-1).""" - - step: int - """The step number within the chain (from 1 to BASE-1).""" - - -class TweakHasher(StrictBaseModel): - """An instance of the Tweakable Hasher for a given config.""" - - config: XmssConfig - """Configuration parameters for the hasher.""" - - poseidon: PoseidonXmss - """Poseidon permutation instance for hashing.""" - - def _encode_tweak(self, tweak: TreeTweak | ChainTweak, length: int) -> list[Fp]: - """ - Encodes a structured tweak object into a list of field elements. - - It converts a high-level tweak context (like "Merkle tree, level 5, index 3") - into a low-level format that can be consumed by the Poseidon1 hash function. - - ### Encoding Algorithm - - 1. **Packing**: The integer components of the tweak are packed into a - single, large integer using bit-shifting. A unique prefix - (`TWEAK_PREFIX_TREE` or `TWEAK_PREFIX_CHAIN`) is included to - guarantee that `TreeTweak` and `ChainTweak` can never produce - the same integer. This process is injective (one-to-one). - - 2. **Decomposition**: The resulting large integer is then decomposed into a - list of base-`P` digits, where `P` is the field prime. This produces the - final list of field elements. - - Args: - tweak: The `TreeTweak` or `ChainTweak` object to encode. - length: The desired number of field elements in the output list. - - Returns: - A list of `length` field elements representing the encoded tweak. - """ - # Pack the tweak's integer fields into a single large integer. - # - # A hardcoded prefix is included for domain separation between tweak types. - match tweak: - case TreeTweak(level=level, index=index): - # Packing scheme: (level << 40) | (index << 8) | PREFIX - acc = (level << 40) | (int(index) << 8) | TWEAK_PREFIX_TREE - case ChainTweak(epoch=epoch, chain_index=chain_index, step=step): - # Packing scheme: (epoch << 24) | (chain_index << 16) | (step << 8) | PREFIX - acc = (int(epoch) << 24) | (chain_index << 16) | (step << 8) | TWEAK_PREFIX_CHAIN - - # Decompose the packed integer `acc` into a list of base-P field elements. - return int_to_base_p(acc, length) - - def apply( - self, - parameter: Parameter, - tweak: TreeTweak | ChainTweak, - message_parts: list[HashDigestVector], - ) -> HashDigestVector: - """ - Applies the tweakable Poseidon1 hash function to a message. - - This is the main entry point for all internal hashing operations. It prepares - the inputs and routes them to the appropriate Poseidon1 function based on - the input size, ensuring optimal performance and security. - - ### Hashing Algorithm - - 1. **Input Assembly**: The final hash input is formed by concatenating: - `[parameter || encoded_tweak || flattened_message_parts]` - - 2. **Mode Selection**: - - For small inputs (1 or 2 `HashDigest` parts), it uses the highly - efficient **compression mode** of Poseidon1. - - For large inputs (many `HashDigest` parts, like a Merkle leaf), - it uses the more flexible **sponge mode**. - - Args: - parameter: The public parameter `P` for this key pair. - tweak: A `TreeTweak` or `ChainTweak` for domain separation. - message_parts: A list of one or more hash digests to be hashed together. - - Returns: - A new hash digest of `HASH_LEN_FE` field elements. - """ - # Get the config for this scheme. - config = self.config - - # Encode the high-level tweak structure into a list of field elements. - encoded_tweak = self._encode_tweak(tweak, config.TWEAK_LEN_FE) - - # Route to the correct Poseidon1 mode based on the input size. - if len(message_parts) == 1: - # Case 1: Hashing a single digest (used in hash chains). - # - # We use the efficient width-16 compression mode. - input_vec = message_parts[0].elements + parameter.elements + encoded_tweak - result = self.poseidon.compress(input_vec, 16, config.HASH_LEN_FE) - - elif len(message_parts) == 2: - # Case 2: Hashing two digests (used for Merkle tree nodes). - # - # We use the slightly larger width-24 compression mode. - input_vec = ( - parameter.elements - + encoded_tweak - + message_parts[0].elements - + message_parts[1].elements - ) - result = self.poseidon.compress(input_vec, 24, config.HASH_LEN_FE) - - else: - # Case 3: Hashing many digests (used for the Merkle tree leaf). - # - # We use the robust sponge mode. - # First, flatten the list of message parts into a single vector. - flattened_message = [elem for part in message_parts for elem in part.elements] - input_vec = parameter.elements + encoded_tweak + flattened_message - - # Create a domain separator for the sponge mode based on the input dimensions. - # - # This ensures the sponge is uniquely configured for this specific hashing task. - lengths = [ - config.PARAMETER_LEN, - config.TWEAK_LEN_FE, - config.DIMENSION, - config.HASH_LEN_FE, - ] - capacity_value = self.poseidon.safe_domain_separator(lengths, config.CAPACITY) - - result = self.poseidon.sponge(input_vec, capacity_value, config.HASH_LEN_FE, 24) - - return HashDigestVector(data=result) - - def hash_chain( - self, - parameter: Parameter, - epoch: Uint64, - chain_index: int, - start_step: int, - num_steps: int, - start_digest: HashDigestVector, - ) -> HashDigestVector: - """ - Performs repeated hashing to traverse a WOTS+ hash chain. - - This function iteratively calls the main `apply` method, creating a new, - unique `ChainTweak` for each step to ensure every hash in the sequence - is domain-separated. - - Args: - parameter: The public parameter `P`. - epoch: The signature epoch, part of the tweak. - chain_index: The index of the hash chain, part of the tweak. - start_step: The starting step number in the chain. - num_steps: The number of hashing steps to perform. - start_digest: The digest to begin hashing from. - - Returns: - The final hash digest after `num_steps` applications. - """ - current_digest = start_digest - for i in range(num_steps): - # Create a unique tweak for the current position in the chain. - # - # The `step` is `start_step + i + 1` because steps are 1-indexed. - tweak = ChainTweak(epoch=epoch, chain_index=chain_index, step=start_step + i + 1) - # Apply the hash function to get the next digest in the chain. - current_digest = self.apply(parameter, tweak, [current_digest]) - return current_digest - - -PROD_TWEAK_HASHER = TweakHasher(config=PROD_CONFIG, poseidon=PROD_POSEIDON) -"""An instance configured for production-level parameters.""" - -TEST_TWEAK_HASHER = TweakHasher(config=TEST_CONFIG, poseidon=TEST_POSEIDON) -"""A lightweight instance for test environments.""" diff --git a/src/lean_spec/subspecs/xmss/types.py b/src/lean_spec/subspecs/xmss/types.py index 41fe7fc00..62efae62a 100644 --- a/src/lean_spec/subspecs/xmss/types.py +++ b/src/lean_spec/subspecs/xmss/types.py @@ -1,6 +1,6 @@ """Base types for the XMSS signature scheme.""" -from typing import Final +from typing import Final, NamedTuple from lean_spec.subspecs.koalabear import Fp @@ -11,102 +11,100 @@ from .constants import PRF_KEY_LENGTH, TARGET_CONFIG -class PRFKey(BaseBytes): +class TreeTweak(NamedTuple): + """Tweak that domain-separates Merkle node hashes by their position.""" + + level: int + """Height in the Merkle tree. + + Layer 0 is the leaf level. + """ + + index: Uint64 + """Node index within its level, counted from the left.""" + + +class ChainTweak(NamedTuple): + """Tweak that domain-separates Winternitz chain hashes by their position.""" + + epoch: Uint64 + """Slot identifier for the one-time signature.""" + + chain_index: int + """Index of the chain within the one-time signature.""" + + step: int """ - The PRF master secret key. + - Step number along the chain. + - Steps are 1-indexed. + - Step zero is the chain start. + """ + - This is a high-entropy byte string that acts as the single root secret from - which all one-time signing keys are deterministically derived. +class PRFKey(BaseBytes): + """The PRF master secret key. + + High-entropy byte string acting as the single root secret. + Every one-time signing key is deterministically derived from this seed. """ LENGTH = PRF_KEY_LENGTH HASH_DIGEST_LENGTH: Final = TARGET_CONFIG.HASH_LEN_FE -""" -The fixed length of a hash digest in field elements. +"""Length of one hash digest in field elements. -Derived from `TARGET_CONFIG.HASH_LEN_FE`. This corresponds to the output length -of the Poseidon1 hash function used in the XMSS scheme. -""" +Corresponds to the Poseidon1 output length used in the XMSS scheme.""" -# Calculate the maximum number of nodes in a sparse Merkle tree layer: -# - A bottom tree has at most 2^(LOG_LIFETIME/2) leaves -# - With padding, we may add up to 2 additional nodes -# - To be generous and future-proof, we use 2^(LOG_LIFETIME/2 + 1) +# Why: a bottom tree spans 2^(LOG_LIFETIME/2) leaves. +# Padding may add up to two extra siblings. +# Doubling that bound leaves room for future-proof layouts without resizing. NODE_LIST_LIMIT: Final = 1 << (TARGET_CONFIG.LOG_LIFETIME // 2 + 1) -""" -The maximum number of nodes that can be stored in a sparse Merkle tree layer. - -Calculated as `2^(LOG_LIFETIME/2 + 1)` from TARGET_CONFIG to accommodate: -- Bottom trees with up to `2^(LOG_LIFETIME/2)` nodes -- Padding overhead (up to 2 additional nodes) -- Future-proofing with 2x margin -""" +"""Maximum number of nodes that can be stored in a sparse Merkle tree layer.""" class HashDigestVector(SSZVector[Fp]): - """ - A single hash digest represented as a fixed-size vector of field elements. + """A single hash digest as a fixed-size vector of field elements. - This is the SSZ-compliant representation of a Poseidon1 hash output. - In SSZ notation: `Vector[Fp, HASH_DIGEST_LENGTH]` - - The fixed size enables efficient serialization when used in collections, - as SSZ can pack these back-to-back without per-element offsets. + The fixed size lets SSZ pack these back-to-back without per-element offsets. """ LENGTH = HASH_DIGEST_LENGTH class HashDigestList(SSZList[HashDigestVector]): - """ - Variable-length list of hash digests. - - In SSZ notation: `List[Vector[Fp, HASH_DIGEST_LENGTH], NODE_LIST_LIMIT]` - - This type is used to represent collections of hash digests in the XMSS scheme. - """ + """Variable-length list of hash digests.""" LIMIT = NODE_LIST_LIMIT class Parameter(SSZVector[Fp]): - """ - The public parameter P. + """The public parameter P. - This is a unique, randomly generated value associated with a single key pair. It - is mixed into every hash computation to "personalize" the hash function, preventing - certain cross-key attacks. It is public knowledge. + Unique, randomly generated value associated with a single key pair. + Mixed into every hash to personalize the function and block cross-key attacks. + Public knowledge. """ LENGTH = TARGET_CONFIG.PARAMETER_LEN class Randomness(SSZVector[Fp]): - """ - The randomness `rho` (ρ) used during signing. - - This value provides a variable input to the message hash, allowing the signer to - repeatedly try hashing until a valid "codeword" is found. It must be included in - the final signature for the verifier to reproduce the same hash. + """The randomness rho used during signing. - SSZ notation: `Vector[Fp, RAND_LEN_FE]` + Variable input to the message hash so the signer can resample until a + valid codeword is found. + Included in the final signature so the verifier reproduces the hash. """ LENGTH = TARGET_CONFIG.RAND_LEN_FE class HashTreeOpening(Container): - """ - A Merkle authentication path. - - This object contains the minimal proof required to connect a specific leaf - to the Merkle root. It consists of the list of all sibling nodes along the - path from the leaf to the top of the tree. + """A Merkle authentication path. - SSZ Container with fields: - - siblings: List[Vector[Fp, HASH_DIGEST_LENGTH], NODE_LIST_LIMIT] + Contains the minimal proof connecting a specific leaf to the Merkle root. + Holds every sibling node along the path from the leaf to the tree top. """ siblings: HashDigestList @@ -114,11 +112,9 @@ class HashTreeOpening(Container): class HashTreeLayer(Container): - """ - Represents a single horizontal "slice" of the sparse Merkle tree. + """A single horizontal slice of the sparse Merkle tree. - Because the tree is sparse, we only store the nodes that are actually computed - for the active range of leaves, not the entire conceptual layer. + The tree is sparse: only nodes computed for the active leaf range are stored. """ start_index: Uint64 @@ -127,27 +123,17 @@ class HashTreeLayer(Container): """SSZ-compliant list of hash digests stored for this layer.""" +# Why: layers run from 0 (leaves) up to LOG_LIFETIME (root) inclusive. LAYERS_LIMIT: Final = TARGET_CONFIG.LOG_LIFETIME + 1 -""" -The maximum number of layers in a subtree. - -This is `LOG_LIFETIME + 1` to accommodate all layers from 0 (leaves) to LOG_LIFETIME (root), -inclusive. For example, with LOG_LIFETIME=32, this allows up to 33 layers. -""" +"""Maximum number of layers in a subtree.""" class HashTreeLayers(SSZList[HashTreeLayer]): - """ - Variable-length list of Merkle tree layers. - - In SSZ notation: `List[HashTreeLayer, LAYERS_LIMIT]` - - This type represents the layers of a subtree, from the lowest layer up to the root. + """Variable-length list of Merkle tree layers. - The number of layers varies based on the subtree structure: - - Bottom trees: `LOG_LIFETIME/2` layers - - Top trees: `LOG_LIFETIME/2` layers - - Maximum: `LOG_LIFETIME + 1` layers + Represents the layers of a subtree, from the lowest layer up to the root. + Bottom and top trees each cover half the depth. + The cap allows the full tree. """ LIMIT = LAYERS_LIMIT diff --git a/src/lean_spec/subspecs/xmss/utils.py b/src/lean_spec/subspecs/xmss/utils.py deleted file mode 100644 index 623c3b765..000000000 --- a/src/lean_spec/subspecs/xmss/utils.py +++ /dev/null @@ -1,155 +0,0 @@ -"""Utility functions for the XMSS signature scheme.""" - -from __future__ import annotations - -from ...types.uint import Uint64 -from ..koalabear import Fp, P -from .rand import Rand -from .types import HashDigestList, HashDigestVector, HashTreeLayer - - -def get_padded_layer( - rand: Rand, nodes: list[HashDigestVector], start_index: Uint64 -) -> HashTreeLayer: - """ - Pads a layer of nodes with random hashes to simplify tree construction. - - This helper enforces a crucial invariant: every active layer must start at an - even index and end at an odd index. This guarantees that every node within - the layer can be neatly paired with a sibling (a left child with a right - child), which dramatically simplifies the parent generation logic by - removing the need to handle edge cases. - - Args: - rand: Random generator for padding values. - nodes: The list of active nodes for the current layer. - start_index: The starting index of the first node in `nodes`. - - Returns: - A new `HashTreeLayer` with the necessary padding applied. - """ - nodes_with_padding: list[HashDigestVector] = [] - end_index = start_index + Uint64(len(nodes)) - Uint64(1) - - # Prepend random padding if the layer starts at an odd index. - if start_index % Uint64(2) == Uint64(1): - nodes_with_padding.append(rand.domain()) - - # The actual start index of the padded layer is always the even - # number at or immediately before the original start_index. - actual_start_index = start_index - (start_index % Uint64(2)) - - # Add the actual node content. - nodes_with_padding.extend(nodes) - - # Append random padding if the layer ends at an even index. - if end_index % Uint64(2) == Uint64(0): - nodes_with_padding.append(rand.domain()) - - return HashTreeLayer( - start_index=actual_start_index, nodes=HashDigestList(data=nodes_with_padding) - ) - - -def int_to_base_p(value: int, num_limbs: int) -> list[Fp]: - """ - Decomposes a large integer into a list of base-P field elements. - - This function performs a standard base conversion, where each "digit" - is an element in the prime field F_p. - - Args: - value: The integer to decompose. - num_limbs: The desired number of output field elements (limbs). - - Returns: - A list of `num_limbs` field elements representing the integer. - """ - limbs: list[Fp] = [] - acc = value - for _ in range(num_limbs): - limbs.append(Fp(value=acc)) - acc //= P - return limbs - - -def expand_activation_time( - log_lifetime: int, desired_activation_slot: int, desired_num_active_slots: int -) -> tuple[int, int]: - """ - Expands and aligns the activation time to top-bottom tree boundaries. - - For efficient top-bottom tree traversal, activation intervals must be aligned to - `sqrt(LIFETIME)` boundaries. This function takes the user's desired activation - interval and expands it to meet the following requirements: - - 1. **Start alignment**: Start slot is rounded down to a multiple of `sqrt(LIFETIME)` - 2. **End alignment**: End slot is rounded up to a multiple of `sqrt(LIFETIME)` - 3. **Minimum duration**: At least `2 * sqrt(LIFETIME)` slots (two bottom trees) - 4. **Lifetime bounds**: Clamped to `[0, LIFETIME)` - - ### Algorithm - - Let `C = 2^(LOG_LIFETIME/2) = sqrt(LIFETIME)` - - 1. Align start downward: `start = desired_start & c_mask` where `c_mask = ~(C - 1)` - 2. Round end upward: `end = (desired_end + C - 1) & c_mask` - 3. Enforce minimum: `if end - start < 2*C: end = start + 2*C` - 4. Clamp to bounds: Adjust if end exceeds `C^2 = LIFETIME` - - ### Example - - For `LOG_LIFETIME = 32` (LIFETIME = 2^32, C = 2^16 = 65536): - - Request: slots [10000, 80000) → 70000 slots - - Aligned: slots [0, 131072) → 131072 slots = 2 bottom trees - - Args: - log_lifetime: The logarithm (base 2) of the total lifetime. - desired_activation_slot: The user's requested first slot. - desired_num_active_slots: The user's requested number of slots. - - Returns: - A tuple `(start_bottom_tree_index, end_bottom_tree_index)` where: - - `start_bottom_tree_index`: Index of the first bottom tree (0, 1, 2, ...) - - `end_bottom_tree_index`: Index past the last bottom tree (exclusive) - - Actual slots: `[start_index * C, end_index * C)` - """ - # Calculate sqrt(LIFETIME) and the alignment mask. - c = 1 << (log_lifetime // 2) # C = 2^(LOG_LIFETIME/2) - c_mask = ~(c - 1) # Mask for rounding to multiples of C - - # Calculate the desired end slot. - desired_end_slot = desired_activation_slot + desired_num_active_slots - - # Step 1: Align start downward to a multiple of C. - start = desired_activation_slot & c_mask - - # Step 2: Round end upward to a multiple of C. - end = (desired_end_slot + c - 1) & c_mask - - # Step 3: Enforce minimum duration of 2*C. - if end - start < 2 * c: - end = start + 2 * c - - # Step 4: Clamp to lifetime bounds [0, C^2). - lifetime = c * c # LIFETIME = C^2 = 2^LOG_LIFETIME - if end > lifetime: - # If the expanded interval exceeds the lifetime, try to fit it at the end. - duration = end - start - - if duration > lifetime: - # The expanded interval is larger than the entire lifetime. - # Use the entire lifetime. - start = 0 - end = lifetime - else: - # Shift the interval to end at the lifetime boundary. - end = lifetime - start = (lifetime - duration) & c_mask # Keep alignment - - # Convert to bottom tree indices. - # Bottom tree i covers slots [i*C, (i+1)*C). - start_bottom_tree_index = start // c - end_bottom_tree_index = end // c - - return (start_bottom_tree_index, end_bottom_tree_index) diff --git a/tests/lean_spec/forks/lstar/forkchoice/test_attestation_target.py b/tests/lean_spec/forks/lstar/forkchoice/test_attestation_target.py index abbc30064..9fdc82dff 100644 --- a/tests/lean_spec/forks/lstar/forkchoice/test_attestation_target.py +++ b/tests/lean_spec/forks/lstar/forkchoice/test_attestation_target.py @@ -23,7 +23,6 @@ Checkpoint, Slot, ValidatorIndex, - ValidatorIndices, ) from tests.lean_spec.helpers import make_store @@ -581,8 +580,7 @@ def test_attestation_target_after_on_block( proposer_pubkey = key_manager.get_public_keys(proposer_1)[1] proposer_type_1 = TypeOneMultiSignature.aggregate( children=[], - raw_xmss=[(proposer_pubkey, proposer_signature)], - xmss_participants=ValidatorIndices(data=[proposer_1]).to_aggregation_bits(), + raw_xmss=[(proposer_1, proposer_pubkey, proposer_signature)], message=block_root, slot=slot_1, ) diff --git a/tests/lean_spec/forks/lstar/forkchoice/test_store_attestations.py b/tests/lean_spec/forks/lstar/forkchoice/test_store_attestations.py index 5810ac00b..4b6f82998 100644 --- a/tests/lean_spec/forks/lstar/forkchoice/test_store_attestations.py +++ b/tests/lean_spec/forks/lstar/forkchoice/test_store_attestations.py @@ -22,7 +22,6 @@ Checkpoint, Slot, ValidatorIndex, - ValidatorIndices, ) from tests.lean_spec.helpers import ( TEST_VALIDATOR_ID, @@ -293,16 +292,15 @@ def test_valid_proof_stored_correctly( data_root = hash_tree_root(attestation_data) # Create valid aggregated proof - xmss_participants = ValidatorIndices(data=participants).to_aggregation_bits() - raw_xmss = list( - zip( - [key_manager[vid].attestation_keypair.public_key for vid in participants], - [key_manager.sign_attestation_data(vid, attestation_data) for vid in participants], - strict=True, + raw_xmss = [ + ( + vid, + key_manager[vid].attestation_keypair.public_key, + key_manager.sign_attestation_data(vid, attestation_data), ) - ) + for vid in participants + ] proof = TypeOneMultiSignature.aggregate( - xmss_participants=xmss_participants, children=[], raw_xmss=raw_xmss, message=data_root, @@ -340,16 +338,15 @@ def test_attestation_data_used_as_key( data_root = hash_tree_root(attestation_data) - xmss_participants = ValidatorIndices(data=participants).to_aggregation_bits() - raw_xmss = list( - zip( - [key_manager[vid].attestation_keypair.public_key for vid in participants], - [key_manager.sign_attestation_data(vid, attestation_data) for vid in participants], - strict=True, + raw_xmss = [ + ( + vid, + key_manager[vid].attestation_keypair.public_key, + key_manager.sign_attestation_data(vid, attestation_data), ) - ) + for vid in participants + ] proof = TypeOneMultiSignature.aggregate( - xmss_participants=xmss_participants, children=[], raw_xmss=raw_xmss, message=data_root, @@ -380,16 +377,15 @@ def test_invalid_proof_rejected(self, key_manager: XmssKeyManager, spec: LstarSp data_root = hash_tree_root(attestation_data) - xmss_participants = ValidatorIndices(data=signers).to_aggregation_bits() - raw_xmss = list( - zip( - [key_manager[vid].attestation_keypair.public_key for vid in signers], - [key_manager.sign_attestation_data(vid, attestation_data) for vid in signers], - strict=True, + raw_xmss = [ + ( + vid, + key_manager[vid].attestation_keypair.public_key, + key_manager.sign_attestation_data(vid, attestation_data), ) - ) + for vid in signers + ] proof = TypeOneMultiSignature.aggregate( - xmss_participants=xmss_participants, children=[], raw_xmss=raw_xmss, message=data_root, @@ -426,19 +422,15 @@ def test_multiple_proofs_accumulate(self, key_manager: XmssKeyManager, spec: Lst # First proof: validators 1 and 2 participants_1 = [ValidatorIndex(1), ValidatorIndex(2)] - xmss_1 = ValidatorIndices(data=participants_1).to_aggregation_bits() - raw_xmss_1 = list( - zip( - [key_manager[vid].attestation_keypair.public_key for vid in participants_1], - [ - key_manager.sign_attestation_data(vid, attestation_data) - for vid in participants_1 - ], - strict=True, + raw_xmss_1 = [ + ( + vid, + key_manager[vid].attestation_keypair.public_key, + key_manager.sign_attestation_data(vid, attestation_data), ) - ) + for vid in participants_1 + ] proof_1 = TypeOneMultiSignature.aggregate( - xmss_participants=xmss_1, children=[], raw_xmss=raw_xmss_1, message=data_root, @@ -447,19 +439,15 @@ def test_multiple_proofs_accumulate(self, key_manager: XmssKeyManager, spec: Lst # Second proof: validators 1 and 3 (validator 1 overlaps) participants_2 = [ValidatorIndex(1), ValidatorIndex(3)] - xmss_2 = ValidatorIndices(data=participants_2).to_aggregation_bits() - raw_xmss_2 = list( - zip( - [key_manager[vid].attestation_keypair.public_key for vid in participants_2], - [ - key_manager.sign_attestation_data(vid, attestation_data) - for vid in participants_2 - ], - strict=True, + raw_xmss_2 = [ + ( + vid, + key_manager[vid].attestation_keypair.public_key, + key_manager.sign_attestation_data(vid, attestation_data), ) - ) + for vid in participants_2 + ] proof_2 = TypeOneMultiSignature.aggregate( - xmss_participants=xmss_2, children=[], raw_xmss=raw_xmss_2, message=data_root, diff --git a/tests/lean_spec/helpers/builders.py b/tests/lean_spec/helpers/builders.py index cbb33e1f6..766dbef49 100644 --- a/tests/lean_spec/helpers/builders.py +++ b/tests/lean_spec/helpers/builders.py @@ -410,16 +410,15 @@ def make_aggregated_proof( ) -> TypeOneMultiSignature: """Create a valid Type-1 aggregated proof for the given participants.""" data_root = hash_tree_root(attestation_data) - xmss_participants = ValidatorIndices(data=participants).to_aggregation_bits() - raw_xmss = list( - zip( - [key_manager.get_public_keys(vid)[0] for vid in participants], - [key_manager.sign_attestation_data(vid, attestation_data) for vid in participants], - strict=True, + raw_xmss = [ + ( + vid, + key_manager.get_public_keys(vid)[0], + key_manager.sign_attestation_data(vid, attestation_data), ) - ) + for vid in participants + ] return TypeOneMultiSignature.aggregate( - xmss_participants=xmss_participants, children=[], raw_xmss=raw_xmss, message=data_root, @@ -480,8 +479,7 @@ def make_signed_block_from_store( proposer_signature = key_manager.sign_block_root(proposer_index, slot, block_root) proposer_type_1 = TypeOneMultiSignature.aggregate( children=[], - raw_xmss=[(proposer_pubkey, proposer_signature)], - xmss_participants=ValidatorIndices(data=[proposer_index]).to_aggregation_bits(), + raw_xmss=[(proposer_index, proposer_pubkey, proposer_signature)], message=block_root, slot=slot, ) diff --git a/tests/lean_spec/subspecs/validator/test_service.py b/tests/lean_spec/subspecs/validator/test_service.py index f1bed1e77..d6609d0fa 100644 --- a/tests/lean_spec/subspecs/validator/test_service.py +++ b/tests/lean_spec/subspecs/validator/test_service.py @@ -26,7 +26,7 @@ from lean_spec.subspecs.validator.registry import ValidatorEntry from lean_spec.subspecs.xmss import TARGET_SIGNATURE_SCHEME from lean_spec.subspecs.xmss.aggregation import TypeOneMultiSignature, TypeTwoMultiSignature -from lean_spec.types import Bytes32, Slot, Uint64, ValidatorIndex, ValidatorIndices +from lean_spec.types import Bytes32, Slot, Uint64, ValidatorIndex from tests.lean_spec.helpers import ( TEST_VALIDATOR_ID, MockNetworkRequester, @@ -1077,11 +1077,9 @@ async def test_block_includes_pending_attestations( signatures.append(sig) public_keys.append(key_manager[vid].attestation_keypair.public_key) - xmss_participants = ValidatorIndices(data=participants).to_aggregation_bits() proof = TypeOneMultiSignature.aggregate( children=[], - raw_xmss=list(zip(public_keys, signatures, strict=True)), - xmss_participants=xmss_participants, + raw_xmss=list(zip(participants, public_keys, signatures, strict=True)), message=data_root, slot=attestation_data.slot, ) diff --git a/tests/lean_spec/subspecs/xmss/test_aggregation.py b/tests/lean_spec/subspecs/xmss/test_aggregation.py index 4a8bd8213..1e345d0c2 100644 --- a/tests/lean_spec/subspecs/xmss/test_aggregation.py +++ b/tests/lean_spec/subspecs/xmss/test_aggregation.py @@ -16,7 +16,6 @@ Checkpoint, Slot, ValidatorIndex, - ValidatorIndices, ) from tests.lean_spec.helpers import make_attestation_data_simple, make_bytes32 @@ -31,16 +30,15 @@ def _sign_and_aggregate( att_data = make_attestation_data_simple(slot, make_bytes32(head), make_bytes32(target), source) data_root = hash_tree_root(att_data) - xmss_participants = ValidatorIndices(data=validator_ids).to_aggregation_bits() - raw_xmss = list( - zip( - [key_manager[vid].attestation_keypair.public_key for vid in validator_ids], - [key_manager.sign_attestation_data(vid, att_data) for vid in validator_ids], - strict=True, + raw_xmss = [ + ( + vid, + key_manager[vid].attestation_keypair.public_key, + key_manager.sign_attestation_data(vid, att_data), ) - ) + for vid in validator_ids + ] return TypeOneMultiSignature.aggregate( - xmss_participants=xmss_participants, children=[], raw_xmss=raw_xmss, message=data_root, @@ -48,35 +46,22 @@ def _sign_and_aggregate( ) -def test_aggregate_rejects_empty_inputs() -> None: - """Aggregation with no signatures and no children raises an error.""" - with pytest.raises(AggregationError, match="At least one raw signature or child proof"): - TypeOneMultiSignature.aggregate( - xmss_participants=None, - children=[], - raw_xmss=[], - message=make_bytes32(0), - slot=Slot(0), - ) - - def test_aggregate_multiple_signatures(key_manager: XmssKeyManager) -> None: """Multiple validators' signatures can be aggregated into a single Type-1 proof.""" source = Checkpoint(root=make_bytes32(10), slot=Slot(0)) att_data = make_attestation_data_simple(Slot(2), make_bytes32(11), make_bytes32(12), source) vids = [ValidatorIndex(i) for i in range(4)] - xmss_participants = ValidatorIndices(data=vids).to_aggregation_bits() - raw_xmss = list( - zip( - [key_manager[vid].attestation_keypair.public_key for vid in vids], - [key_manager.sign_attestation_data(vid, att_data) for vid in vids], - strict=True, + raw_xmss = [ + ( + vid, + key_manager[vid].attestation_keypair.public_key, + key_manager.sign_attestation_data(vid, att_data), ) - ) + for vid in vids + ] proof = TypeOneMultiSignature.aggregate( - xmss_participants=xmss_participants, children=[], raw_xmss=raw_xmss, message=hash_tree_root(att_data), @@ -102,17 +87,16 @@ def test_aggregate_children_with_raw_signatures(key_manager: XmssKeyManager) -> # Additional raw signatures: validators 2, 3 extra_vids = [ValidatorIndex(2), ValidatorIndex(3)] - xmss_participants = ValidatorIndices(data=extra_vids).to_aggregation_bits() - raw_xmss = list( - zip( - [key_manager[vid].attestation_keypair.public_key for vid in extra_vids], - [key_manager.sign_attestation_data(vid, att_data) for vid in extra_vids], - strict=True, + raw_xmss = [ + ( + vid, + key_manager[vid].attestation_keypair.public_key, + key_manager.sign_attestation_data(vid, att_data), ) - ) + for vid in extra_vids + ] parent = TypeOneMultiSignature.aggregate( - xmss_participants=xmss_participants, children=[ ( child, @@ -149,7 +133,6 @@ def test_aggregate_three_children(key_manager: XmssKeyManager) -> None: child_c_pks = [key_manager[ValidatorIndex(2)].attestation_keypair.public_key] parent = TypeOneMultiSignature.aggregate( - xmss_participants=None, children=[(child_a, child_a_pks), (child_b, child_b_pks), (child_c, child_c_pks)], raw_xmss=[], message=hash_tree_root(att_data), @@ -185,14 +168,12 @@ def test_aggregate_children_of_children(key_manager: XmssKeyManager) -> None: # Level 1: two intermediate proofs. mid_ab = TypeOneMultiSignature.aggregate( - xmss_participants=None, children=[(leaf_a, leaf_a_pks), (leaf_b, leaf_b_pks)], raw_xmss=[], message=msg, slot=att_data.slot, ) mid_cd = TypeOneMultiSignature.aggregate( - xmss_participants=None, children=[(leaf_c, leaf_c_pks), (leaf_d, leaf_d_pks)], raw_xmss=[], message=msg, @@ -201,7 +182,6 @@ def test_aggregate_children_of_children(key_manager: XmssKeyManager) -> None: # Level 2: final root proof. root = TypeOneMultiSignature.aggregate( - xmss_participants=None, children=[(mid_ab, leaf_a_pks + leaf_b_pks), (mid_cd, leaf_c_pks + leaf_d_pks)], raw_xmss=[], message=msg, @@ -236,17 +216,16 @@ def test_aggregate_mixed_children_and_raw_multiple(key_manager: XmssKeyManager) # Additional raw signatures from validators 2 and 3. extra_vids = [ValidatorIndex(2), ValidatorIndex(3)] - xmss_participants = ValidatorIndices(data=extra_vids).to_aggregation_bits() - raw_xmss = list( - zip( - [key_manager[vid].attestation_keypair.public_key for vid in extra_vids], - [key_manager.sign_attestation_data(vid, att_data) for vid in extra_vids], - strict=True, + raw_xmss = [ + ( + vid, + key_manager[vid].attestation_keypair.public_key, + key_manager.sign_attestation_data(vid, att_data), ) - ) + for vid in extra_vids + ] proof = TypeOneMultiSignature.aggregate( - xmss_participants=xmss_participants, children=[(child_a, child_a_pks), (child_b, child_b_pks)], raw_xmss=raw_xmss, message=msg, @@ -335,7 +314,6 @@ def test_aggregate_child_signed_different_message_fails(key_manager: XmssKeyMana # The binding rejects mismatching messages during recursive aggregation. with pytest.raises(AggregationError): TypeOneMultiSignature.aggregate( - xmss_participants=None, children=[(child_a, child_a_pks), (child_b, child_b_pks)], raw_xmss=[], message=hash_tree_root(att_data_b), @@ -343,60 +321,6 @@ def test_aggregate_child_signed_different_message_fails(key_manager: XmssKeyMana ) -def test_aggregate_rejects_single_child_without_raw(key_manager: XmssKeyManager) -> None: - """A single child without raw signatures is rejected (need at least two children).""" - placeholder = ByteList512KiB(data=b"\x00") - stub_child = TypeOneMultiSignature( - participants=ValidatorIndices(data=[ValidatorIndex(0)]).to_aggregation_bits(), - proof=placeholder, - ) - - with pytest.raises(AggregationError, match="At least two child proofs"): - TypeOneMultiSignature.aggregate( - xmss_participants=None, - children=[ - ( - stub_child, - [ - key_manager[ValidatorIndex(i)].attestation_keypair.public_key - for i in range(1) - ], - ) - ], - raw_xmss=[], - message=make_bytes32(0), - slot=Slot(0), - ) - - -def test_aggregate_rejects_mismatched_participant_count( - key_manager: XmssKeyManager, -) -> None: - """Participant bitfield count must match raw signature count.""" - source = Checkpoint(root=make_bytes32(60), slot=Slot(0)) - att_data = make_attestation_data_simple(Slot(7), make_bytes32(61), make_bytes32(62), source) - - # Claim 2 participants but only provide 1 signature. - xmss_participants = ValidatorIndices( - data=[ValidatorIndex(0), ValidatorIndex(1)] - ).to_aggregation_bits() - raw_xmss = [ - ( - key_manager[ValidatorIndex(0)].attestation_keypair.public_key, - key_manager.sign_attestation_data(ValidatorIndex(0), att_data), - ) - ] - - with pytest.raises(AggregationError, match="does not match"): - TypeOneMultiSignature.aggregate( - xmss_participants=xmss_participants, - children=[], - raw_xmss=raw_xmss, - message=hash_tree_root(att_data), - slot=att_data.slot, - ) - - def test_type_two_aggregate_rejects_empty_parts() -> None: """Type-2 aggregation requires at least one Type-1 input.""" with pytest.raises(AggregationError, match="at least one Type-1 input"): diff --git a/tests/lean_spec/subspecs/xmss/test_interface.py b/tests/lean_spec/subspecs/xmss/test_interface.py index 34ddb6ce4..2092df1d2 100644 --- a/tests/lean_spec/subspecs/xmss/test_interface.py +++ b/tests/lean_spec/subspecs/xmss/test_interface.py @@ -4,6 +4,7 @@ import pytest +from lean_spec.subspecs.xmss.encoding import target_sum_encode from lean_spec.subspecs.xmss.interface import ( TEST_SIGNATURE_SCHEME, GeneralizedXmssScheme, @@ -37,7 +38,7 @@ def _test_correctness_roundtrip( # Sign the message at the chosen slot. # - # This might take a moment as it may try multiple `rho` values. + # This might take a moment as it may try multiple rho values. signature = scheme.sign(sk, test_slot, message) # Verification of the valid signature must succeed. @@ -56,9 +57,11 @@ def _test_correctness_roundtrip( # In that case, verification will succeed, which is expected behavior for identical codewords. # # We detect this by checking if both messages encode to the same codeword. - original_codeword = scheme.encoder.encode(pk.parameter, message, signature.rho, test_slot) - tampered_codeword = scheme.encoder.encode( - pk.parameter, tampered_message, signature.rho, test_slot + original_codeword = target_sum_encode( + scheme.poseidon, scheme.config, pk.parameter, message, signature.rho, test_slot + ) + tampered_codeword = target_sum_encode( + scheme.poseidon, scheme.config, pk.parameter, tampered_message, signature.rho, test_slot ) if tampered_codeword != original_codeword: diff --git a/tests/lean_spec/subspecs/xmss/test_merkle_tree.py b/tests/lean_spec/subspecs/xmss/test_merkle_tree.py index ae6094bc6..9b3d59823 100644 --- a/tests/lean_spec/subspecs/xmss/test_merkle_tree.py +++ b/tests/lean_spec/subspecs/xmss/test_merkle_tree.py @@ -2,21 +2,23 @@ import pytest -from lean_spec.subspecs.xmss.rand import PROD_RAND, Rand -from lean_spec.subspecs.xmss.subtree import HashSubTree, verify_path -from lean_spec.subspecs.xmss.tweak_hash import ( - PROD_TWEAK_HASHER, +from lean_spec.subspecs.xmss.constants import PROD_CONFIG, XmssConfig +from lean_spec.subspecs.xmss.field import random_domain, random_parameter +from lean_spec.subspecs.xmss.merkle import HashSubTree, verify_path +from lean_spec.subspecs.xmss.poseidon import PROD_POSEIDON, PoseidonXmss +from lean_spec.subspecs.xmss.types import ( + HashDigestList, + HashDigestVector, + HashTreeOpening, TreeTweak, - TweakHasher, ) -from lean_spec.subspecs.xmss.types import HashDigestList, HashDigestVector, HashTreeOpening from lean_spec.types import Uint64 from lean_spec.types.exceptions import SSZValueError def _run_commit_open_verify_roundtrip( - hasher: TweakHasher, - rand: Rand, + poseidon: PoseidonXmss, + config: XmssConfig, num_leaves: int, depth: int, start_index: int, @@ -33,22 +35,23 @@ def _run_commit_open_verify_roundtrip( 5. Verify that each path is valid for its corresponding leaf and root. Args: - hasher: The tweakable hash instance for computing parent nodes. - rand: Random generator for padding values. + poseidon: Cached Poseidon1 engine. + config: Active XMSS configuration. num_leaves: The number of active leaves in the tree. depth: The total depth of the Merkle tree. start_index: The starting index of the first active leaf. leaf_parts_len: The number of digests that constitute a single leaf. """ # SETUP: Generate a random parameter and the raw leaf data. - parameter = rand.parameter() + parameter = random_parameter(config) leaves: list[list[HashDigestVector]] = [ - [rand.domain() for _ in range(leaf_parts_len)] for _ in range(num_leaves) + [random_domain(config) for _ in range(leaf_parts_len)] for _ in range(num_leaves) ] # HASH LEAVES: Compute the layer 0 nodes by hashing the leaf parts. leaf_hashes: list[HashDigestVector] = [ - hasher.apply( + poseidon.tweak_hash( + config, parameter, TreeTweak(level=0, index=Uint64(start_index + i)), leaf_parts, @@ -58,8 +61,8 @@ def _run_commit_open_verify_roundtrip( # COMMIT: Build the Merkle tree from the leaf hashes. tree = HashSubTree.new( - hasher=hasher, - rand=rand, + poseidon=poseidon, + config=config, lowest_layer=Uint64(0), depth=Uint64(depth), start_index=Uint64(start_index), @@ -73,7 +76,8 @@ def _run_commit_open_verify_roundtrip( position = Uint64(start_index + i) opening = tree.path(position) is_valid = verify_path( - hasher=hasher, + poseidon=poseidon, + config=config, parameter=parameter, root=root, position=position, @@ -107,7 +111,7 @@ def test_commit_open_verify_roundtrip( assert start_index + num_leaves <= (1 << depth) _run_commit_open_verify_roundtrip( - PROD_TWEAK_HASHER, PROD_RAND, num_leaves, depth, start_index, leaf_parts_len + PROD_POSEIDON, PROD_CONFIG, num_leaves, depth, start_index, leaf_parts_len ) @@ -127,29 +131,26 @@ def test_ssz_validation_rejects_excessive_depth(self) -> None: creating malformed openings at the SSZ level. The check in verify_path is defense-in-depth for deserialized data. """ - rand = PROD_RAND - # Attempting to create a list with 33 siblings raises at the type level. - excessive_siblings = [rand.domain() for _ in range(33)] + excessive_siblings = [random_domain(PROD_CONFIG) for _ in range(33)] with pytest.raises(SSZValueError): HashDigestList(data=excessive_siblings) def test_rejects_position_exceeding_tree_capacity(self) -> None: """verify_path returns False when position >= 2^depth.""" - rand = PROD_RAND - hasher = PROD_TWEAK_HASHER - parameter = rand.parameter() + parameter = random_parameter(PROD_CONFIG) - root = rand.domain() - leaf_parts = [rand.domain()] + root = random_domain(PROD_CONFIG) + leaf_parts = [random_domain(PROD_CONFIG)] # Create an opening with depth=4 (supports positions 0-15). - siblings = [rand.domain() for _ in range(4)] + siblings = [random_domain(PROD_CONFIG) for _ in range(4)] opening = HashTreeOpening(siblings=HashDigestList(data=siblings)) # Position 16 is out of bounds for depth 4 (capacity = 2^4 = 16). result = verify_path( - hasher=hasher, + poseidon=PROD_POSEIDON, + config=PROD_CONFIG, parameter=parameter, root=root, position=Uint64(16), @@ -160,7 +161,8 @@ def test_rejects_position_exceeding_tree_capacity(self) -> None: # Position 100 is also out of bounds. result = verify_path( - hasher=hasher, + poseidon=PROD_POSEIDON, + config=PROD_CONFIG, parameter=parameter, root=root, position=Uint64(100), @@ -171,21 +173,20 @@ def test_rejects_position_exceeding_tree_capacity(self) -> None: def test_valid_position_at_boundary(self) -> None: """verify_path accepts position at maximum valid value (2^depth - 1).""" - rand = PROD_RAND - hasher = PROD_TWEAK_HASHER - parameter = rand.parameter() + parameter = random_parameter(PROD_CONFIG) - root = rand.domain() - leaf_parts = [rand.domain()] + root = random_domain(PROD_CONFIG) + leaf_parts = [random_domain(PROD_CONFIG)] # Create an opening with depth=4. - siblings = [rand.domain() for _ in range(4)] + siblings = [random_domain(PROD_CONFIG) for _ in range(4)] opening = HashTreeOpening(siblings=HashDigestList(data=siblings)) # Position 15 is the maximum valid position for depth 4. # This should not return False due to bounds check (may still fail root check). result = verify_path( - hasher=hasher, + poseidon=PROD_POSEIDON, + config=PROD_CONFIG, parameter=parameter, root=root, position=Uint64(15), diff --git a/tests/lean_spec/subspecs/xmss/test_message_hash.py b/tests/lean_spec/subspecs/xmss/test_message_hash.py index 58861155a..6d42a43e4 100644 --- a/tests/lean_spec/subspecs/xmss/test_message_hash.py +++ b/tests/lean_spec/subspecs/xmss/test_message_hash.py @@ -7,23 +7,25 @@ TEST_CONFIG, TWEAK_PREFIX_MESSAGE, ) -from lean_spec.subspecs.xmss.message_hash import ( - TEST_MESSAGE_HASHER, +from lean_spec.subspecs.xmss.encoding import ( + aborting_decode, + encode_epoch, + encode_message, + message_hash, ) -from lean_spec.subspecs.xmss.rand import TEST_RAND -from lean_spec.subspecs.xmss.types import Randomness -from lean_spec.subspecs.xmss.utils import int_to_base_p +from lean_spec.subspecs.xmss.field import int_to_base_p, random_field_elements +from lean_spec.subspecs.xmss.poseidon import TEST_POSEIDON +from lean_spec.subspecs.xmss.types import Parameter, Randomness from lean_spec.types import Bytes32, Uint64 def test_encode_message() -> None: - """Tests `encode_message` with various message patterns.""" + """Tests encode_message with various message patterns.""" config = TEST_CONFIG - hasher = TEST_MESSAGE_HASHER # All-zero message msg_zeros = Bytes32(b"\x00" * 32) - encoded_zeros = hasher.encode_message(msg_zeros) + encoded_zeros = encode_message(config, msg_zeros) assert len(encoded_zeros) == config.MSG_LEN_FE assert all(fe == Fp(value=0) for fe in encoded_zeros) @@ -31,14 +33,13 @@ def test_encode_message() -> None: msg_max = Bytes32(b"\xff" * 32) acc = int.from_bytes(msg_max, "little") expected_max = int_to_base_p(acc, config.MSG_LEN_FE) - assert hasher.encode_message(msg_max) == expected_max + assert encode_message(config, msg_max) == expected_max def test_encode_epoch() -> None: """ - Tests `encode_epoch` for correctness and injectivity. + Tests encode_epoch for correctness and injectivity. """ - hasher = TEST_MESSAGE_HASHER config = TEST_CONFIG # Test specific values from the Rust reference tests. @@ -46,27 +47,26 @@ def test_encode_epoch() -> None: for epoch in test_epochs: acc = (epoch << 8) | TWEAK_PREFIX_MESSAGE expected = int_to_base_p(acc, config.TWEAK_LEN_FE) - assert hasher.encode_epoch(Uint64(epoch)) == expected + assert encode_epoch(config, Uint64(epoch)) == expected # Test for injectivity. It is highly unlikely for a collision to occur # with a few random samples if the encoding is injective. num_trials = 1000 seen_encodings: set[tuple[Fp, ...]] = set() for i in range(num_trials): - encoding = tuple(hasher.encode_epoch(Uint64(i))) + encoding = tuple(encode_epoch(config, Uint64(i))) assert encoding not in seen_encodings seen_encodings.add(encoding) def test_aborting_decode_known_decomposition() -> None: """Verifies aborting decode with a hand-computed example.""" - hasher = TEST_MESSAGE_HASHER config = TEST_CONFIG # Pick an arbitrary quotient multiplier to build a valid field element. d_value = 5 - fe_list = [Fp(value=config.Q * d_value)] * hasher.config.MH_HASH_LEN_FE - result = hasher._aborting_decode(fe_list) + fe_list = [Fp(value=config.Q * d_value)] * config.MH_HASH_LEN_FE + result = aborting_decode(config, fe_list) assert result is not None assert len(result) == config.DIMENSION @@ -77,45 +77,42 @@ def test_aborting_decode_known_decomposition() -> None: for _ in range(config.Z): digits_per_fe.append(remaining % config.BASE) remaining //= config.BASE - all_digits = (digits_per_fe * hasher.config.MH_HASH_LEN_FE)[: config.DIMENSION] + all_digits = (digits_per_fe * config.MH_HASH_LEN_FE)[: config.DIMENSION] assert result == all_digits def test_aborting_decode_boundary() -> None: """Tests that FE = P-2 succeeds and FE = P-1 aborts.""" - hasher = TEST_MESSAGE_HASHER config = TEST_CONFIG # P - 2 is the largest valid value (just below Q * BASE^Z = P - 1). - fe_valid = [Fp(value=P - 2)] * hasher.config.MH_HASH_LEN_FE - result = hasher._aborting_decode(fe_valid) + fe_valid = [Fp(value=P - 2)] * config.MH_HASH_LEN_FE + result = aborting_decode(config, fe_valid) assert result is not None assert len(result) == config.DIMENSION assert all(0 <= d < config.BASE for d in result) # P - 1 triggers the abort (A_i >= Q * BASE^Z). fe_abort = [Fp(value=P - 1)] - result = hasher._aborting_decode(fe_abort) + result = aborting_decode(config, fe_abort) assert result is None def test_apply_output_is_valid_codeword() -> None: """ - Tests that the output of `apply` is `None` or a valid codeword with - DIMENSION digits each in `[0, BASE-1]`. + Tests that the output of message_hash is None or a valid codeword with + DIMENSION digits each in [0, BASE-1]. """ config = TEST_CONFIG - hasher = TEST_MESSAGE_HASHER - rand = TEST_RAND # Setup with random inputs. - parameter = rand.parameter() + parameter = Parameter(data=random_field_elements(config.PARAMETER_LEN)) epoch = Uint64(313) - randomness = Randomness(data=rand.field_elements(config.RAND_LEN_FE)) + randomness = Randomness(data=random_field_elements(config.RAND_LEN_FE)) message = Bytes32(b"\xaa" * 32) # Call the message hash function. - result = hasher.apply(parameter, epoch, randomness, message) + result = message_hash(TEST_POSEIDON, config, parameter, epoch, randomness, message) # The aborting decode may return None, but in practice it almost never does. assert result is not None diff --git a/tests/lean_spec/subspecs/xmss/test_prf.py b/tests/lean_spec/subspecs/xmss/test_prf.py index cf4c63e00..efa9fa3bc 100644 --- a/tests/lean_spec/subspecs/xmss/test_prf.py +++ b/tests/lean_spec/subspecs/xmss/test_prf.py @@ -4,29 +4,27 @@ PRF_KEY_LENGTH, TEST_CONFIG, ) -from lean_spec.subspecs.xmss.prf import TEST_PRF +from lean_spec.subspecs.xmss.prf import prf_apply, prf_key_gen from lean_spec.subspecs.xmss.types import PRFKey from lean_spec.types import Uint64 def test_key_gen_is_random() -> None: """ - Performs a sanity check on `key_gen` to ensure it's not deterministic + Performs a sanity check on key_gen to ensure it's not deterministic or producing trivial outputs. This test mirrors the logic from the reference Rust implementation. """ - prf = TEST_PRF - # Check that the key has the correct length. - key = prf.key_gen() + key = prf_key_gen() assert len(key) == PRF_KEY_LENGTH # Generate multiple keys and ensure they are not all identical. # # This is a basic check to ensure we are getting fresh randomness. num_trials = 10 - keys = {prf.key_gen() for _ in range(num_trials)} + keys = {prf_key_gen() for _ in range(num_trials)} assert len(keys) == num_trials # Check that the keys are not filled with a single repeated byte. @@ -35,7 +33,7 @@ def test_key_gen_is_random() -> None: # such a key, so this is a good health check. all_same_count = 0 for _ in range(num_trials): - key = prf.key_gen() + key = prf_key_gen() # A set will have size 1 if all elements are the same. if len(set(key)) == 1: all_same_count += 1 @@ -44,32 +42,31 @@ def test_key_gen_is_random() -> None: def test_apply_is_sensitive_to_inputs() -> None: """ - Tests that changing any input to `apply` results in a different output. + Tests that changing any input to apply results in a different output. This confirms that all parts of the input (key, epoch, chain_index) are being correctly absorbed by the hash function. """ - prf = TEST_PRF config = TEST_CONFIG # Generate a baseline output with a set of initial inputs. key1 = PRFKey(b"\x11" * PRF_KEY_LENGTH) epoch1 = Uint64(10) chain_index1 = Uint64(20) - baseline_output = prf.apply(key1, epoch1, chain_index1) + baseline_output = prf_apply(config, key1, epoch1, chain_index1) assert len(baseline_output) == config.HASH_LEN_FE # Test sensitivity to the key. key2 = PRFKey(b"\x22" * PRF_KEY_LENGTH) - output_key_changed = prf.apply(key2, epoch1, chain_index1) + output_key_changed = prf_apply(config, key2, epoch1, chain_index1) assert baseline_output != output_key_changed # Test sensitivity to the epoch. epoch2 = Uint64(11) - output_epoch_changed = prf.apply(key1, epoch2, chain_index1) + output_epoch_changed = prf_apply(config, key1, epoch2, chain_index1) assert baseline_output != output_epoch_changed # Test sensitivity to the chain_index. chain_index2 = Uint64(21) - output_index_changed = prf.apply(key1, epoch1, chain_index2) + output_index_changed = prf_apply(config, key1, epoch1, chain_index2) assert baseline_output != output_index_changed diff --git a/tests/lean_spec/subspecs/xmss/test_security_levels.py b/tests/lean_spec/subspecs/xmss/test_security_levels.py index 815d9bed9..86b4a46c8 100644 --- a/tests/lean_spec/subspecs/xmss/test_security_levels.py +++ b/tests/lean_spec/subspecs/xmss/test_security_levels.py @@ -10,7 +10,7 @@ The security analysis follows the framework of [DKKW25c] Section 6. Theorem 1 gives an advantage bound as the sum of five terms. Each term divided by attacker -running time must be at most `2^{-(k + log5)}`, yielding four independent +running time must be at most 2^{-(k + log5)}, yielding four independent constraints (Parameter Requirements 2 and 3): 1. Digest (SM-UD/SM-PRE via Eq 8-9 / Eq 15) @@ -20,7 +20,7 @@ The abort correction from [HKKTW26] Corollary 1 and Remark 14 adjusts the message hash bound: the aborting decode effectively enlarges the output space -to `|H|/(1 - theta)`, where `theta` is the abort probability. +to |H|/(1 - theta), where theta is the abort probability. """ import math @@ -51,10 +51,10 @@ def _compute_security_levels(config: XmssConfig) -> dict[str, float]: Returns a dict with keys: - - `k_classical`: effective classical security (bits) - - `k_quantum`: effective quantum security (bits) - - `expected_attempts`: expected signing attempts per message - - `signing_failure_log2`: log2 of probability that all MAX_TRIES attempts fail + - k_classical: effective classical security (bits) + - k_quantum: effective quantum security (bits) + - expected_attempts: expected signing attempts per message + - signing_failure_log2: log2 of probability that all MAX_TRIES attempts fail """ v = config.DIMENSION w_bits = int(math.log2(config.BASE)) @@ -188,8 +188,8 @@ def test_prod_abort_probability_is_negligible() -> None: The aborting decode rejection probability must be negligible. From [HKKTW26] Section 6.1: each FE has abort probability 1/P. - Over `ceil(v/Z)` FEs, the total abort probability is approximately - `ceil(v/Z) / P`. + Over ceil(v/Z) FEs, the total abort probability is approximately + ceil(v/Z) / P. """ config = PROD_CONFIG ell = math.ceil(config.DIMENSION / config.Z) diff --git a/tests/lean_spec/subspecs/xmss/test_utils.py b/tests/lean_spec/subspecs/xmss/test_utils.py index 40d0bfbd8..dc790bf08 100644 --- a/tests/lean_spec/subspecs/xmss/test_utils.py +++ b/tests/lean_spec/subspecs/xmss/test_utils.py @@ -7,15 +7,12 @@ from lean_spec.subspecs.koalabear.field import Fp, P from lean_spec.subspecs.xmss.constants import TEST_CONFIG -from lean_spec.subspecs.xmss.prf import TEST_PRF -from lean_spec.subspecs.xmss.rand import TEST_RAND -from lean_spec.subspecs.xmss.subtree import HashSubTree -from lean_spec.subspecs.xmss.tweak_hash import TEST_TWEAK_HASHER +from lean_spec.subspecs.xmss.field import int_to_base_p +from lean_spec.subspecs.xmss.interface import _expand_activation_time +from lean_spec.subspecs.xmss.merkle import HashSubTree +from lean_spec.subspecs.xmss.poseidon import TEST_POSEIDON +from lean_spec.subspecs.xmss.prf import prf_key_gen from lean_spec.subspecs.xmss.types import Parameter -from lean_spec.subspecs.xmss.utils import ( - expand_activation_time, - int_to_base_p, -) from lean_spec.types import Uint64 @@ -82,8 +79,8 @@ def test_expand_activation_time( expected_start_tree: int, expected_end_tree: int, ) -> None: - """Tests that expand_activation_time correctly aligns and expands activation intervals.""" - start_tree, end_tree = expand_activation_time(log_lifetime, desired_activation, desired_num) + """Tests that _expand_activation_time correctly aligns and expands activation intervals.""" + start_tree, end_tree = _expand_activation_time(log_lifetime, desired_activation, desired_num) assert start_tree == expected_start_tree assert end_tree == expected_end_tree @@ -114,7 +111,7 @@ def test_hash_subtree_from_prf_key() -> None: config = TEST_CONFIG # Generate a PRF key - prf_key = TEST_PRF.key_gen() + prf_key = prf_key_gen() # Generate a random parameter parameter = Parameter( @@ -123,9 +120,7 @@ def test_hash_subtree_from_prf_key() -> None: # Generate bottom tree 0 bottom_tree = HashSubTree.from_prf_key( - prf=TEST_PRF, - hasher=TEST_TWEAK_HASHER, - rand=TEST_RAND, + poseidon=TEST_POSEIDON, config=config, prf_key=prf_key, bottom_tree_index=Uint64(0), @@ -150,16 +145,14 @@ def test_hash_subtree_from_prf_key() -> None: def test_hash_subtree_from_prf_key_deterministic() -> None: """Tests that HashSubTree.from_prf_key is deterministic.""" config = TEST_CONFIG - prf_key = TEST_PRF.key_gen() + prf_key = prf_key_gen() parameter = Parameter( data=[Fp(value=secrets.randbelow(P)) for _ in range(config.PARAMETER_LEN)] ) # Generate the same bottom tree twice tree1 = HashSubTree.from_prf_key( - prf=TEST_PRF, - hasher=TEST_TWEAK_HASHER, - rand=TEST_RAND, + poseidon=TEST_POSEIDON, config=config, prf_key=prf_key, bottom_tree_index=Uint64(0), @@ -167,9 +160,7 @@ def test_hash_subtree_from_prf_key_deterministic() -> None: ) tree2 = HashSubTree.from_prf_key( - prf=TEST_PRF, - hasher=TEST_TWEAK_HASHER, - rand=TEST_RAND, + poseidon=TEST_POSEIDON, config=config, prf_key=prf_key, bottom_tree_index=Uint64(0), @@ -183,16 +174,14 @@ def test_hash_subtree_from_prf_key_deterministic() -> None: def test_hash_subtree_from_prf_key_different_indices() -> None: """Tests that different bottom tree indices produce different trees.""" config = TEST_CONFIG - prf_key = TEST_PRF.key_gen() + prf_key = prf_key_gen() parameter = Parameter( data=[Fp(value=secrets.randbelow(P)) for _ in range(config.PARAMETER_LEN)] ) # Generate two different bottom trees tree0 = HashSubTree.from_prf_key( - prf=TEST_PRF, - hasher=TEST_TWEAK_HASHER, - rand=TEST_RAND, + poseidon=TEST_POSEIDON, config=config, prf_key=prf_key, bottom_tree_index=Uint64(0), @@ -200,9 +189,7 @@ def test_hash_subtree_from_prf_key_different_indices() -> None: ) tree1 = HashSubTree.from_prf_key( - prf=TEST_PRF, - hasher=TEST_TWEAK_HASHER, - rand=TEST_RAND, + poseidon=TEST_POSEIDON, config=config, prf_key=prf_key, bottom_tree_index=Uint64(1), From 66a7b0ddddaa2484aba2550cd684365a4248bd89 Mon Sep 17 00:00:00 2001 From: Thomas Coratger <60488569+tcoratger@users.noreply.github.com> Date: Wed, 27 May 2026 17:25:49 +0200 Subject: [PATCH 2/9] refactor(xmss): single-source the bottom-tree width Add a LEAVES_PER_BOTTOM_TREE property to XmssConfig and route the four recomputations of 2^(LOG_LIFETIME / 2) through it. Sign now reuses get_prepared_interval for its prepared-window bound check, removing the duplicated window arithmetic, and drops a redundant int() cast on slot. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lean_spec/subspecs/xmss/constants.py | 5 +++++ src/lean_spec/subspecs/xmss/interface.py | 18 ++++++++---------- src/lean_spec/subspecs/xmss/merkle.py | 2 +- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/lean_spec/subspecs/xmss/constants.py b/src/lean_spec/subspecs/xmss/constants.py index 7f5c193f0..a95d6ec56 100644 --- a/src/lean_spec/subspecs/xmss/constants.py +++ b/src/lean_spec/subspecs/xmss/constants.py @@ -74,6 +74,11 @@ def LIFETIME(self) -> Uint64: """ return Uint64(1 << self.LOG_LIFETIME) + @property + def LEAVES_PER_BOTTOM_TREE(self) -> int: + """Slots covered by one bottom tree, W = sqrt(LIFETIME) = 2^(LOG_LIFETIME / 2).""" + return 1 << (self.LOG_LIFETIME // 2) + @property def MH_HASH_LEN_FE(self) -> int: """Number of Poseidon output field elements needed for the aborting decode.""" diff --git a/src/lean_spec/subspecs/xmss/interface.py b/src/lean_spec/subspecs/xmss/interface.py index be393beba..bf9ba2211 100644 --- a/src/lean_spec/subspecs/xmss/interface.py +++ b/src/lean_spec/subspecs/xmss/interface.py @@ -109,7 +109,7 @@ def key_gen(self, activation_slot: Slot, num_active_slots: Uint64) -> KeyPair: start_bottom_tree_index, end_bottom_tree_index = _expand_activation_time( config.LOG_LIFETIME, int(activation_slot), int(num_active_slots) ) - leaves_per_bottom_tree = 1 << (config.LOG_LIFETIME // 2) + leaves_per_bottom_tree = config.LEAVES_PER_BOTTOM_TREE actual_activation_slot = start_bottom_tree_index * leaves_per_bottom_tree actual_num_active_slots = ( end_bottom_tree_index - start_bottom_tree_index @@ -204,13 +204,11 @@ def sign(self, sk: SecretKey, slot: Slot, message: Bytes32) -> Signature: # Phase 1b: prepared bound. # Without two adjacent bottom trees we cannot produce a path without # paying the cost of regenerating them on the fly. - leaves_per_bottom_tree = 1 << (config.LOG_LIFETIME // 2) - prepared_start = int(sk.left_bottom_tree_index) * leaves_per_bottom_tree - prepared_end = prepared_start + 2 * leaves_per_bottom_tree - if not (prepared_start <= slot_int < prepared_end): + prepared = self.get_prepared_interval(sk) + if slot_int not in prepared: raise ValueError( f"Slot {slot} is outside the prepared interval " - f"[{prepared_start}, {prepared_end}). " + f"[{prepared.start}, {prepared.stop}). " f"Call advance_preparation() to slide the window forward." ) @@ -247,9 +245,9 @@ def sign(self, sk: SecretKey, slot: Slot, message: Bytes32) -> Signature: # Phase 4: combined Merkle path through both trees. # The signed slot picks the bottom tree on the prepared window's left or right. - boundary = (int(sk.left_bottom_tree_index) + 1) * leaves_per_bottom_tree + boundary = prepared.start + config.LEAVES_PER_BOTTOM_TREE bottom_tree = sk.left_bottom_tree if slot_int < boundary else sk.right_bottom_tree - path = combined_path(sk.top_tree, bottom_tree, Uint64(int(slot))) + path = combined_path(sk.top_tree, bottom_tree, Uint64(slot)) return Signature(path=path, rho=rho, hashes=HashDigestList(data=ots_hashes)) @@ -328,7 +326,7 @@ def get_prepared_interval(self, sk: SecretKey) -> range: A signer can sign any slot in this range without paying the cost of rebuilding a bottom tree from the PRF. """ - leaves_per_bottom_tree = 1 << (self.config.LOG_LIFETIME // 2) + leaves_per_bottom_tree = self.config.LEAVES_PER_BOTTOM_TREE start = int(sk.left_bottom_tree_index) * leaves_per_bottom_tree return range(start, start + 2 * leaves_per_bottom_tree) @@ -347,7 +345,7 @@ def advance_preparation(self, sk: SecretKey) -> SecretKey: Returns: A secret key with the window shifted by one bottom tree. """ - leaves_per_bottom_tree = 1 << (self.config.LOG_LIFETIME // 2) + leaves_per_bottom_tree = self.config.LEAVES_PER_BOTTOM_TREE left_index = int(sk.left_bottom_tree_index) # Phase 1: no advancement once the activation interval is fully consumed. diff --git a/src/lean_spec/subspecs/xmss/merkle.py b/src/lean_spec/subspecs/xmss/merkle.py index eb924e2cf..7878f1556 100644 --- a/src/lean_spec/subspecs/xmss/merkle.py +++ b/src/lean_spec/subspecs/xmss/merkle.py @@ -283,7 +283,7 @@ def from_prf_key( The requested bottom tree. """ # Each bottom tree covers sqrt(LIFETIME) consecutive epochs. - leaves_per_bottom_tree = 1 << (config.LOG_LIFETIME // 2) + leaves_per_bottom_tree = config.LEAVES_PER_BOTTOM_TREE start_epoch = bottom_tree_index * Uint64(leaves_per_bottom_tree) end_epoch = start_epoch + Uint64(leaves_per_bottom_tree) From 85b21662b8bcb2a5dd65761b630f19f687f952b4 Mon Sep 17 00:00:00 2001 From: Thomas Coratger <60488569+tcoratger@users.noreply.github.com> Date: Wed, 27 May 2026 17:53:20 +0200 Subject: [PATCH 3/9] docs(xmss): clearer line-by-line documentation in the scheme interface Rework the docstrings and inline comments of the interface module to follow the project documentation rules: one sentence per line, no backticks, and WHY-focused comments rather than restatements of the code. - _expand_activation_time: conceptual overview of the bottom-tree model plus phase-labeled comments explaining the round-down/round-up bit tricks. - key_gen: lead with the memory-bound rationale; fix the phase numbering so the inline labels match execution order. - sign: explain each phase from the scheme's perspective (synchronized one-time keys, the target-sum layer giving incomparability, Winternitz chain release, and the Merkle opening). Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lean_spec/subspecs/xmss/interface.py | 104 +++++++++++++---------- 1 file changed, 61 insertions(+), 43 deletions(-) diff --git a/src/lean_spec/subspecs/xmss/interface.py b/src/lean_spec/subspecs/xmss/interface.py index bf9ba2211..5f8ee3f67 100644 --- a/src/lean_spec/subspecs/xmss/interface.py +++ b/src/lean_spec/subspecs/xmss/interface.py @@ -16,51 +16,52 @@ def _expand_activation_time( log_lifetime: int, desired_activation_slot: int, desired_num_active_slots: int ) -> tuple[int, int]: - """Align a requested activation interval to top-bottom tree boundaries. + """Snap a requested slot window onto whole bottom trees. - Phase 1: round start down to a multiple of sqrt(LIFETIME). - Phase 2: round end up to a multiple of sqrt(LIFETIME). - Phase 3: enforce a minimum duration of two bottom trees. - Phase 4: clamp to the lifetime bound, shifting the interval if needed. + # Overview + + A bottom tree covers C consecutive slots, where C is the square root of the lifetime. + The lifetime is C such trees laid end to end, so C * C slots in total. Args: - log_lifetime: Base-2 logarithm of the lifetime. - desired_activation_slot: First slot requested. - desired_num_active_slots: Number of slots requested. + log_lifetime: Base-2 logarithm of the lifetime in slots. + desired_activation_slot: First slot the caller wants to sign. + desired_num_active_slots: Number of slots the caller wants to sign. Returns: - The pair (start_bottom_tree_index, end_bottom_tree_index). - Actual slots covered are [start * C, end * C) where C = sqrt(LIFETIME). + The half-open bottom-tree index range (start, end). + It covers slots [start * C, end * C). """ - # C = sqrt(LIFETIME). - # c_mask rounds down to multiples of C. + # C is one bottom tree's worth of slots, the square root of the lifetime. + # C is a power of two, so clearing the low bits rounds a slot down to a tree boundary. c = 1 << (log_lifetime // 2) c_mask = ~(c - 1) desired_end_slot = desired_activation_slot + desired_num_active_slots - # Phase 1 + 2: snap the interval endpoints onto bottom-tree boundaries. + # Phase 1: round the start down and the end up onto tree boundaries. + # Adding C - 1 before clearing the low bits rounds the end up rather than down. start = desired_activation_slot & c_mask end = (desired_end_slot + c - 1) & c_mask - # Phase 3: at least two bottom trees so the prepared window always fits. + # Phase 2: widen to two trees so the resident signing window always fits. if end - start < 2 * c: end = start + 2 * c - # Phase 4: clamp to [0, LIFETIME). + # Phase 3: clamp the window into the lifetime. lifetime = c * c if end > lifetime: duration = end - start if duration > lifetime: - # The requested interval is wider than the lifetime. - # Use the whole lifetime. + # The request is wider than the whole lifetime, so cover all of it. start = 0 end = lifetime else: - # Shift the interval back so it ends exactly at the lifetime boundary. + # Slide the window back so it ends exactly at the lifetime boundary. end = lifetime start = (lifetime - duration) & c_mask + # Convert the slot boundaries to bottom-tree indices. return (start // c, end // c) @@ -76,21 +77,20 @@ class GeneralizedXmssScheme(StrictBaseModel): def key_gen(self, activation_slot: Slot, num_active_slots: Uint64) -> KeyPair: """Generate a fresh key pair active for an aligned slot range. - Phase 1: align the requested interval to sqrt(LIFETIME) boundaries. - Phase 2: draw the master PRF key and public parameter. - Phase 3: materialize the two leftmost bottom trees. - Phase 4: generate every other bottom tree, retaining only its root. - Phase 5: build the top tree from all bottom-tree roots. + # Overview + + The signer keeps only the two leftmost bottom trees resident, plus every bottom-tree root. + That bounds secret-key memory near the square root of the lifetime. - The returned key may cover a wider interval than requested because - of boundary alignment and the two-tree minimum window. + The requested range is snapped outward to whole bottom trees. + So the returned key may cover more slots than asked for. Args: activation_slot: Requested first signable slot. num_active_slots: Requested number of signable slots. Returns: - A KeyPair with both halves of the scheme. + A key pair holding the public root and the resident signer state. Raises: ValueError: When the requested range exceeds the lifetime. @@ -101,19 +101,18 @@ def key_gen(self, activation_slot: Slot, num_active_slots: Uint64) -> KeyPair: if int(activation_slot) + int(num_active_slots) > int(config.LIFETIME): raise ValueError("Activation range exceeds the key's lifetime.") - # Phase 2: draw the master secret and the public parameter. + # Phase 1: draw the master secret and the public parameter. parameter = random_parameter(config) prf_key = prf_key_gen() - # Phase 1: align onto bottom-tree boundaries. + # Phase 2: align the requested interval onto bottom-tree boundaries. start_bottom_tree_index, end_bottom_tree_index = _expand_activation_time( config.LOG_LIFETIME, int(activation_slot), int(num_active_slots) ) - leaves_per_bottom_tree = config.LEAVES_PER_BOTTOM_TREE - actual_activation_slot = start_bottom_tree_index * leaves_per_bottom_tree + actual_activation_slot = start_bottom_tree_index * config.LEAVES_PER_BOTTOM_TREE actual_num_active_slots = ( end_bottom_tree_index - start_bottom_tree_index - ) * leaves_per_bottom_tree + ) * config.LEAVES_PER_BOTTOM_TREE # Phase 3: build the two leftmost bottom trees and keep them resident. left_bottom_tree = HashSubTree.from_prf_key( @@ -157,6 +156,7 @@ def key_gen(self, activation_slot: Slot, num_active_slots: Uint64) -> KeyPair: bottom_tree_roots=bottom_tree_roots, ) + # Pack the public root and the resident signer state into the key pair. pk = PublicKey(root=top_tree.root(), parameter=parameter) sk = SecretKey( prf_key=prf_key, @@ -197,13 +197,18 @@ def sign(self, sk: SecretKey, slot: Slot, message: Bytes32) -> Signature: slot_int = int(slot) activation_int = int(sk.activation_slot) - # Phase 1a: activation bound. + # Phase 1a: the slot must lie in the key's activation range. + # + # This is a synchronized one-time scheme: each slot consumes a distinct one-time key. + # The key only holds material for the contiguous range fixed at generation. if not (activation_int <= slot_int < activation_int + int(sk.num_active_slots)): raise ValueError("Key is not active for the specified slot.") - # Phase 1b: prepared bound. - # Without two adjacent bottom trees we cannot produce a path without - # paying the cost of regenerating them on the fly. + # Phase 1b: the slot must lie in the prepared window. + # + # The signature opens this slot's leaf through the bottom tree that holds it. + # Only the two adjacent resident bottom trees are available. + # A slot outside them would force rebuilding a tree from the PRF, so we refuse. prepared = self.get_prepared_interval(sk) if slot_int not in prepared: raise ValueError( @@ -212,8 +217,14 @@ def sign(self, sk: SecretKey, slot: Slot, message: Bytes32) -> Signature: f"Call advance_preparation() to slide the window forward." ) - # Phase 2: deterministic search for valid randomness. - # Randomness comes from the PRF so signing is reproducible. + # Phase 2: find randomness whose encoding lands on the target-sum layer. + # + # A valid codeword must have digits summing to the target. + # That constant sum is what makes distinct codewords incomparable, hence unforgeable. + # The message hash hits the layer only for some randomness, so resample until it does. + # + # The randomness is derived from the PRF, keyed by the message and an attempt counter. + # Signing the same message twice is therefore reproducible. for attempts in range(config.MAX_TRIES): rho = prf_get_randomness(config, sk.prf_key, slot, message, Uint64(attempts)) codeword = target_sum_encode(self.poseidon, config, sk.parameter, message, rho, slot) @@ -224,11 +235,15 @@ def sign(self, sk: SecretKey, slot: Slot, message: Bytes32) -> Signature: f"Failed to find a valid message encoding after {config.MAX_TRIES} tries." ) - # Sanity guard against an encoder returning the wrong number of digits. + # A valid codeword carries exactly one digit per Winternitz chain. if len(codeword) != config.DIMENSION: raise RuntimeError("Encoding is broken: returned too many or too few chunks.") - # Phase 3: walk each Winternitz chain to the released hash. + # Phase 3: release each Winternitz chain at the position its digit selects. + # + # Every chain starts from a secret derived from the PRF. + # Hashing that start forward by the digit gives the value to reveal. + # The verifier later finishes the remaining steps to reach the chain end. ots_hashes: list[HashDigestVector] = [] for chain_index, steps in enumerate(codeword): start_digest = prf_apply(config, sk.prf_key, slot, Uint64(chain_index)) @@ -243,12 +258,16 @@ def sign(self, sk: SecretKey, slot: Slot, message: Bytes32) -> Signature: ) ots_hashes.append(ots_digest) - # Phase 4: combined Merkle path through both trees. - # The signed slot picks the bottom tree on the prepared window's left or right. + # Phase 4: open this slot's leaf up to the public root. + # + # The opening climbs the bottom tree that holds the slot, then the top tree. + # The slot's side of the prepared window selects which resident bottom tree to climb. boundary = prepared.start + config.LEAVES_PER_BOTTOM_TREE bottom_tree = sk.left_bottom_tree if slot_int < boundary else sk.right_bottom_tree path = combined_path(sk.top_tree, bottom_tree, Uint64(slot)) + # The signature carries the opening, the randomness, and the released chain values. + # The randomness lets the verifier recompute the same codeword. return Signature(path=path, rho=rho, hashes=HashDigestList(data=ots_hashes)) def verify(self, pk: PublicKey, slot: Slot, message: Bytes32, sig: Signature) -> bool: @@ -345,11 +364,10 @@ def advance_preparation(self, sk: SecretKey) -> SecretKey: Returns: A secret key with the window shifted by one bottom tree. """ - leaves_per_bottom_tree = self.config.LEAVES_PER_BOTTOM_TREE left_index = int(sk.left_bottom_tree_index) # Phase 1: no advancement once the activation interval is fully consumed. - next_prepared_end_slot = (left_index + 3) * leaves_per_bottom_tree + next_prepared_end_slot = (left_index + 3) * self.config.LEAVES_PER_BOTTOM_TREE activation_end = int(sk.activation_slot) + int(sk.num_active_slots) if next_prepared_end_slot > activation_end: return sk From 299a5f022635d071703909f38e2d48a3b15f029b Mon Sep 17 00:00:00 2001 From: Thomas Coratger <60488569+tcoratger@users.noreply.github.com> Date: Wed, 27 May 2026 18:07:01 +0200 Subject: [PATCH 4/9] refactor(xmss): build bottom trees directly to their own height Give the subtree builder an optional highest_layer bound so a bottom tree stops at depth/2 instead of building the full tree to the global root and discarding the upper half. The single-node root is then taken from the top built layer, keeping the absolute-index selection that handles odd bottom tree indices. The constructed layers are byte-identical to before. Also simplify the authentication-path index math to plain int arithmetic, matching the verifier, and drop the redundant Uint64 wraps. Co-Authored-By: Claude Opus 4.7 (1M context) --- .claude/worktrees/poseidon1-refactor | 1 + src/lean_spec/subspecs/xmss/merkle.py | 44 ++++++++++++++----------- src/lean_spec/subspecs/xmss/poseidon.py | 10 ++---- 3 files changed, 27 insertions(+), 28 deletions(-) create mode 160000 .claude/worktrees/poseidon1-refactor diff --git a/.claude/worktrees/poseidon1-refactor b/.claude/worktrees/poseidon1-refactor new file mode 160000 index 000000000..548d8719e --- /dev/null +++ b/.claude/worktrees/poseidon1-refactor @@ -0,0 +1 @@ +Subproject commit 548d8719e875ce2dc25d9fac6d7557d8a570db6e diff --git a/src/lean_spec/subspecs/xmss/merkle.py b/src/lean_spec/subspecs/xmss/merkle.py index 7878f1556..23b52bec9 100644 --- a/src/lean_spec/subspecs/xmss/merkle.py +++ b/src/lean_spec/subspecs/xmss/merkle.py @@ -95,12 +95,13 @@ def new( start_index: Uint64, parameter: Parameter, lowest_layer_nodes: list[HashDigestVector], + highest_layer: Uint64 | None = None, ) -> Self: - """Build a subtree from its lowest layer up to the root. + """Build a subtree from its lowest layer up to a bounding layer. Phase 1: pad the input layer to the alignment invariant. Phase 2: hash each sibling pair to produce the next layer up. - Phase 3: pad each new layer and continue to the root. + Phase 3: pad each new layer and continue to the bounding layer. Args: poseidon: Cached Poseidon1 engine. @@ -110,10 +111,14 @@ def new( start_index: Absolute index of the first input node. parameter: Public parameter for the hash function. lowest_layer_nodes: Active nodes at the lowest layer. + highest_layer: Layer to stop building at, defaulting to the full depth. Returns: - A subtree containing every layer from lowest_layer to the root. + A subtree containing every layer from lowest_layer up to highest_layer. """ + # Build to the global root unless a lower bounding layer is requested. + highest_layer = depth if highest_layer is None else highest_layer + # The input nodes must fit in the layer they belong to. max_positions = 1 << int(depth - lowest_layer) if int(start_index) + len(lowest_layer_nodes) > max_positions: @@ -128,7 +133,7 @@ def new( layers.append(current) # Phases 2 + 3: hash sibling pairs, pad, repeat. - for level in range(lowest_layer, depth): + for level in range(lowest_layer, highest_layer): parent_start = current.start_index // Uint64(2) parents = [ poseidon.tweak_hash( @@ -202,9 +207,8 @@ def new_bottom_tree( ) -> Self: """Build one bottom tree from leaf hashes up to its standalone root. - Phase 1: build a full subtree from layer 0 using the provided leaves. - Phase 2: drop the layers above depth/2 produced by extra padding. - Phase 3: replace the highest layer with a single-node root extracted from middle. + Phase 1: build the layers from 0 up to the bottom-tree root layer. + Phase 2: replace that padded top layer with its single-node root. Args: poseidon: Cached Poseidon1 engine. @@ -230,8 +234,8 @@ def new_bottom_tree( f"Expected {leaves_per_tree} leaves for depth={depth}, got {len(leaves)}." ) - # Phase 1: build a full subtree from layer 0. - full_tree = cls.new( + # Phase 1: build only layers 0 through depth/2, the bottom tree's own height. + subtree = cls.new( poseidon=poseidon, config=config, lowest_layer=Uint64(0), @@ -239,22 +243,22 @@ def new_bottom_tree( start_index=bottom_tree_index * Uint64(leaves_per_tree), parameter=parameter, lowest_layer_nodes=leaves, + highest_layer=Uint64(depth // 2), ) - # Phase 3: extract the middle layer's root entry for this bottom tree. - middle = full_tree.layers[depth // 2] - root_idx = int(bottom_tree_index - middle.start_index) + # Phase 2: the top built layer is padded to a sibling pair. + # The real root is the node at this tree's index, not always position zero. + # An odd index leaves a random pad at position zero, so select by absolute index. + top = subtree.layers[-1] + root_idx = int(bottom_tree_index - top.start_index) root_layer = HashTreeLayer( start_index=bottom_tree_index, - nodes=HashDigestList(data=[middle.nodes[root_idx]]), + nodes=HashDigestList(data=[top.nodes[root_idx]]), ) - - # Phase 2 + 3: keep layers 0 through depth/2 - 1, then append the standalone root. - truncated = list(full_tree.layers[: depth // 2]) return cls( depth=Uint64(depth), lowest_layer=Uint64(0), - layers=HashTreeLayers(data=truncated + [root_layer]), + layers=HashTreeLayers(data=list(subtree.layers[:-1]) + [root_layer]), ) @classmethod @@ -353,18 +357,18 @@ def path(self, position: Uint64) -> HashTreeOpening: raise ValueError(f"Position {position} out of bounds.") siblings: list[HashDigestVector] = [] - pos = position + pos = int(position) # Stop one short of the root layer. # The root has no sibling. for layer in self.layers[:-1]: # The sibling sits at the position with the last bit flipped, then we # rebase by the layer's start_index because the layer is sparse. - sibling_idx = int((pos ^ Uint64(1)) - layer.start_index) + sibling_idx = (pos ^ 1) - int(layer.start_index) if not (0 <= sibling_idx < len(layer.nodes)): raise ValueError(f"Sibling index {sibling_idx} out of bounds.") siblings.append(layer.nodes[sibling_idx]) - pos = pos // Uint64(2) + pos //= 2 return HashTreeOpening(siblings=HashDigestList(data=siblings)) diff --git a/src/lean_spec/subspecs/xmss/poseidon.py b/src/lean_spec/subspecs/xmss/poseidon.py index c340a9a2c..b14962937 100644 --- a/src/lean_spec/subspecs/xmss/poseidon.py +++ b/src/lean_spec/subspecs/xmss/poseidon.py @@ -1,9 +1,4 @@ -"""Poseidon1 hash engine in compression and sponge modes for the Generalized XMSS scheme. - -Poseidon1 is arithmetization-friendly. -Hashing across hash chains, Merkle nodes, and Merkle leaves uses a single -permutation, which keeps the in-SNARK aggregation step cheap. -""" +"""Poseidon1 hash engine in compression and sponge modes for the Generalized XMSS scheme.""" from itertools import batched @@ -267,5 +262,4 @@ def hash_chain( """Poseidon1 engine with production parameters.""" TEST_POSEIDON = PROD_POSEIDON -"""Test environment reuses the production Poseidon1 parameters. -Only the surrounding configuration differs between modes.""" +"""Test environment reuses the production Poseidon1 parameters.""" From 6a46a77540abd67bcc546919c003515f585ec1eb Mon Sep 17 00:00:00 2001 From: Thomas Coratger <60488569+tcoratger@users.noreply.github.com> Date: Wed, 27 May 2026 18:08:02 +0200 Subject: [PATCH 5/9] chore: stop tracking local agent worktree gitlink A parallel worktree under .claude/worktrees was accidentally staged and committed as a gitlink. Remove it from tracking and ignore the directory. Co-Authored-By: Claude Opus 4.7 (1M context) --- .claude/worktrees/poseidon1-refactor | 1 - .gitignore | 3 +++ 2 files changed, 3 insertions(+), 1 deletion(-) delete mode 160000 .claude/worktrees/poseidon1-refactor diff --git a/.claude/worktrees/poseidon1-refactor b/.claude/worktrees/poseidon1-refactor deleted file mode 160000 index 548d8719e..000000000 --- a/.claude/worktrees/poseidon1-refactor +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 548d8719e875ce2dc25d9fac6d7557d8a570db6e diff --git a/.gitignore b/.gitignore index 420842c6d..705422e42 100644 --- a/.gitignore +++ b/.gitignore @@ -168,3 +168,6 @@ scripts/ # Client codebase used by Claude skill for running reference tests clients/ + +# Local agent worktrees +.claude/worktrees/ From f60ff044401fc4d99c335c627a9d27002fbeb212 Mon Sep 17 00:00:00 2001 From: Thomas Coratger <60488569+tcoratger@users.noreply.github.com> Date: Wed, 27 May 2026 18:54:38 +0200 Subject: [PATCH 6/9] docs(xmss): clarify merkle module and colocate the tree layer types Rewrite the merkle documentation to the project doc standard: a module header explaining the top-bottom split and sliding-window memory bound, overview-only docstrings backed by phase-labeled bodies, and concrete layout and climb traces. Move the signer's internal tree representation next to its only user: the layer container, its list, and the padding logic now live in the merkle module as a named constructor, while the signature-facing opening type stays among the base types. Inline the single-use layer cap. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lean_spec/subspecs/xmss/merkle.py | 287 +++++++++++++----- src/lean_spec/subspecs/xmss/types.py | 28 -- .../lstar/ssz/test_xmss_containers.py | 2 +- 3 files changed, 208 insertions(+), 109 deletions(-) diff --git a/src/lean_spec/subspecs/xmss/merkle.py b/src/lean_spec/subspecs/xmss/merkle.py index 23b52bec9..a0f2cad85 100644 --- a/src/lean_spec/subspecs/xmss/merkle.py +++ b/src/lean_spec/subspecs/xmss/merkle.py @@ -1,28 +1,45 @@ -"""Sparse Merkle subtrees for top-bottom XMSS traversal. +r""" +Sparse Merkle subtrees for the top-bottom traversal of an XMSS key. -The XMSS lifetime tree is split into one top tree and many bottom trees. -Each bottom tree covers sqrt(LIFETIME) consecutive slots. -The signer keeps the full top tree plus two adjacent bottom trees resident, -forming a sliding window of 2*sqrt(LIFETIME) signable slots. +# Overview -This bounds the secret-key memory at O(sqrt(LIFETIME)) instead of O(LIFETIME). +The long-lived public key of an XMSS signature is the root of one Merkle tree. +That tree commits to one one-time public key per slot of the key's lifetime. +A signature opens the leaf for its slot with a path of sibling hashes up to the root. + +A full lifetime tree has one leaf per slot. +For a lifetime of 2^32 slots, holding every node in memory is infeasible. + +# The top-bottom split + +The tree is cut into one top tree sitting above many bottom trees. +Each bottom tree covers a contiguous run of leaves, the square root of the lifetime of them. +There are that many bottom trees, and their roots are exactly the leaves of the top tree. + +The signer keeps the whole top tree resident, plus the bottom trees around the active slot. +As the active slot advances, stale bottom trees are dropped and fresh ones regenerated on demand. +Resident memory stays near the square root of the lifetime instead of the full lifetime. + +# Sparse layers + +Only a window of leaves is resident, so a stored layer holds a contiguous slice, not a full level. +Each layer records the absolute index of its first node, keeping positions in full-tree coordinates. """ from itertools import batched from typing import Self from lean_spec.types import Uint64 +from lean_spec.types.collections import SSZList from lean_spec.types.container import Container -from .constants import XmssConfig +from .constants import TARGET_CONFIG, XmssConfig from .field import random_domain from .poseidon import PoseidonXmss from .prf import prf_apply from .types import ( HashDigestList, HashDigestVector, - HashTreeLayer, - HashTreeLayers, HashTreeOpening, Parameter, PRFKey, @@ -30,60 +47,124 @@ ) -def _padded_layer( - config: XmssConfig, - nodes: list[HashDigestVector], - start_index: Uint64, -) -> HashTreeLayer: - """Pad a layer so every node has a sibling and parent generation has no edge cases. +class HashTreeLayer(Container): + """ + A single horizontal slice of a sparse Merkle subtree. - Invariant: the padded layer starts at an even index and ends at an odd index. - A single-node layer is allowed when the layer is the root. + The tree is sparse, so a layer stores only the nodes computed for the active leaf range. """ - nodes_with_padding: list[HashDigestVector] = [] - end_index = start_index + Uint64(len(nodes)) - Uint64(1) - # Prepend one random sibling when the layer begins on an odd index. - if start_index % Uint64(2) == Uint64(1): - nodes_with_padding.append(random_domain(config)) + start_index: Uint64 + """Absolute index of the first stored node within its level.""" + + nodes: HashDigestList + """Stored hash digests for this layer, ordered left to right.""" - # The padded layer always starts on the even index at or before start_index. - actual_start_index = start_index - (start_index % Uint64(2)) + @classmethod + def padded( + cls, + config: XmssConfig, + nodes: list[HashDigestVector], + start_index: Uint64, + ) -> Self: + """ + Build a layer whose nodes can all be paired at the next level up. - nodes_with_padding.extend(nodes) + # Why pad - # Append one random sibling when the layer ends on an even index. - if end_index % Uint64(2) == Uint64(0): - nodes_with_padding.append(random_domain(config)) + The level above pairs nodes two at a time, then hashes each pair. + - A run starting on an odd index lacks a left neighbor for its first node. + - A run ending on an even index lacks a right neighbor for its last node. - return HashTreeLayer( - start_index=actual_start_index, - nodes=HashDigestList(data=nodes_with_padding), - ) + Padding either gap with a fresh random digest lets every node pair. + + # Invariant + + The result starts on an even index and ends on an odd index. + A single-node layer is the sole exception, since it is a subtree root. + + # Layout + + indices 5, 6, 7 -> [pad] 5 6 7 starts 4, ends 7 + indices 4, 5, 6 -> 4 5 6 [pad] starts 4, ends 7 + + Args: + config: Active XMSS configuration. + nodes: Active nodes at this layer, in ascending index order. + start_index: Absolute index of the first active node. + + Returns: + A layer satisfying the alignment invariant above. + """ + nodes_with_padding: list[HashDigestVector] = [] + end_index = start_index + Uint64(len(nodes)) - Uint64(1) + + # Prepend one random sibling when the layer begins on an odd index. + if start_index % Uint64(2) == Uint64(1): + nodes_with_padding.append(random_domain(config)) + + # The padded layer always starts on the even index at or before start_index. + actual_start_index = start_index - (start_index % Uint64(2)) + + nodes_with_padding.extend(nodes) + + # Append one random sibling when the layer ends on an even index. + if end_index % Uint64(2) == Uint64(0): + nodes_with_padding.append(random_domain(config)) + + return cls( + start_index=actual_start_index, + nodes=HashDigestList(data=nodes_with_padding), + ) + + +class HashTreeLayers(SSZList[HashTreeLayer]): + """ + The layers of a subtree, ordered from the lowest layer up to the root. + + A bottom tree and a top tree each cover half the depth. + The cap admits the full lifetime tree. + """ + + LIMIT = TARGET_CONFIG.LOG_LIFETIME + 1 + """Layers run from level zero, the leaves, up to the lifetime depth, the root, inclusive.""" class HashSubTree(Container): - """Sparse Merkle subtree of an XMSS lifetime tree. + """ + A contiguous slice of an XMSS lifetime tree, stored layer by layer. - Stores layers from lowest_layer up to the subtree root. - A bottom tree has lowest_layer = 0 and covers a window of leaves. - A top tree has lowest_layer = LOG_LIFETIME/2 and covers the bottom-tree roots. + # Overview - Layout invariant: every active layer starts on an even index and ends on - an odd index except for the single-node root layer. + A subtree holds every node from its lowest layer up to a single root node. + Two shapes exist, told apart by where the lowest layer sits. + + - A bottom tree starts at layer zero and covers a window of leaves. + - A top tree starts at the split layer and covers the bottom-tree roots. + + # Invariant + + Every stored layer starts on an even index and ends on an odd index. + The exception is the top layer, which holds the single subtree root. """ depth: Uint64 """Depth of the full lifetime tree this subtree belongs to. - A subtree starting at layer k stores depth - k layers.""" + + A subtree starting at layer k stores depth - k layers. + """ lowest_layer: Uint64 - """Lowest layer included in this subtree. - Zero for bottom trees, LOG_LIFETIME/2 for top trees.""" + """Lowest layer included in this subtree: + - Zero for bottom trees, + - LOG_LIFETIME/2 for top trees. + """ layers: HashTreeLayers """Layers stored from lowest_layer up to the subtree root. - The last entry holds a single node, the subtree root.""" + + The last entry holds a single node, the subtree root. + """ @classmethod def new( @@ -97,11 +178,14 @@ def new( lowest_layer_nodes: list[HashDigestVector], highest_layer: Uint64 | None = None, ) -> Self: - """Build a subtree from its lowest layer up to a bounding layer. + """ + Build a subtree from its lowest layer up to a bounding layer. - Phase 1: pad the input layer to the alignment invariant. - Phase 2: hash each sibling pair to produce the next layer up. - Phase 3: pad each new layer and continue to the bounding layer. + # Overview + + Each layer is hashed pairwise into the layer above, climbing one level per step. + Padding keeps every intermediate layer aligned so sibling pairs form cleanly. + Building stops at the bounding layer, which defaults to the global root. Args: poseidon: Cached Poseidon1 engine. @@ -114,7 +198,10 @@ def new( highest_layer: Layer to stop building at, defaulting to the full depth. Returns: - A subtree containing every layer from lowest_layer up to highest_layer. + A subtree holding every layer from the lowest layer up to the bounding layer. + + Raises: + ValueError: When the input nodes do not fit the level they start in. """ # Build to the global root unless a lower bounding layer is requested. highest_layer = depth if highest_layer is None else highest_layer @@ -129,7 +216,7 @@ def new( # Phase 1: pad the input layer. layers: list[HashTreeLayer] = [] - current = _padded_layer(config, lowest_layer_nodes, start_index) + current = HashTreeLayer.padded(config, lowest_layer_nodes, start_index) layers.append(current) # Phases 2 + 3: hash sibling pairs, pad, repeat. @@ -144,7 +231,7 @@ def new( ) for i, (left, right) in enumerate(batched(current.nodes, 2)) ] - current = _padded_layer(config, parents, parent_start) + current = HashTreeLayer.padded(config, parents, parent_start) layers.append(current) return cls( @@ -205,10 +292,13 @@ def new_bottom_tree( parameter: Parameter, leaves: list[HashDigestVector], ) -> Self: - """Build one bottom tree from leaf hashes up to its standalone root. + """ + Build one bottom tree from its leaf hashes up to its standalone root. - Phase 1: build the layers from 0 up to the bottom-tree root layer. - Phase 2: replace that padded top layer with its single-node root. + # Overview + + A bottom tree spans the lower half of the lifetime tree for one window of slots. + Its root is later placed as a single leaf of the top tree. Args: poseidon: Cached Poseidon1 engine. @@ -219,10 +309,11 @@ def new_bottom_tree( leaves: Pre-hashed one-time public keys for this bottom tree's slots. Returns: - A subtree with layers 0 through depth/2 ending in the bottom-tree root. + A subtree spanning the lower half of the tree, ending in the bottom-tree root. Raises: - ValueError: When depth is odd or the leaf count does not match sqrt(LIFETIME). + ValueError: When the depth is odd, or the leaf count is not the square + root of the lifetime. """ if depth % 2 != 0: raise ValueError(f"Depth must be even for top-bottom split, got {depth}.") @@ -270,11 +361,20 @@ def from_prf_key( bottom_tree_index: Uint64, parameter: Parameter, ) -> Self: - """Regenerate one bottom tree on demand from the master PRF key. + """ + Regenerate one bottom tree on demand from the master secret seed. + + # Overview + + The secret key is not stored slot by slot. + One short master seed deterministically expands into every chain start. + This lets the signer keep only a sliding window resident and rebuild the rest on demand. + + # What a leaf is - Phase 1: for every epoch in the bottom tree, derive chain starts via PRF. - Phase 2: hash each chain for BASE - 1 steps to obtain the chain endpoints. - Phase 3: hash chain endpoints into a leaf, then build the bottom tree. + Each slot owns one one-time signature made of many independent hash chains. + Walking a chain from its secret start to its far end yields one public chain end. + Hashing all of a slot's chain ends together produces the leaf committed at that slot. Args: poseidon: Cached Poseidon1 engine. @@ -293,7 +393,8 @@ def from_prf_key( leaf_hashes: list[HashDigestVector] = [] for epoch in range(start_epoch, end_epoch): - # Phases 1 + 2: derive each chain start, then walk it to the public endpoint. + # Derive each chain start from the seed, then walk it to its public end. + # The far end is the chain start hashed forward the full length minus one. chain_ends: list[HashDigestVector] = [] for chain_index in range(config.DIMENSION): start_digest = prf_apply(config, prf_key, Uint64(epoch), Uint64(chain_index)) @@ -308,7 +409,7 @@ def from_prf_key( ) chain_ends.append(end_digest) - # Phase 3: hash all chain endpoints into the leaf for this epoch. + # The leaf for this slot is the hash of all its chain ends together. leaf_tweak = TreeTweak(level=0, index=Uint64(epoch)) leaf_hash = poseidon.tweak_hash(config, parameter, leaf_tweak, chain_ends) leaf_hashes.append(leaf_hash) @@ -335,16 +436,29 @@ def root(self) -> HashDigestVector: return self.layers[-1].nodes[0] def path(self, position: Uint64) -> HashTreeOpening: - """Build the authentication path from a leaf up to the subtree root. + """ + Collect the sibling hashes that connect one leaf to the subtree root. + + # Overview + + At each level the node has exactly one sibling, the node sharing its parent. + That sibling sits at the current position with its lowest bit flipped. + Recording one sibling per level, then halving the position to climb, yields the full path. + The root has no sibling, so the walk stops one level below it. + + # Layout - For a subtree covering layers L through H, the opening contains H - L siblings, - one per layer between L and H - 1. + climbing from leaf position 5 in a three-level subtree: + + position 5 -> sibling 4 (flip low bit), then halve to 2 + position 2 -> sibling 3 (flip low bit), then halve to 1 + position 1 -> root, no sibling, stop Args: - position: Absolute index of the leaf in the full tree coordinate system. + position: Absolute index of the leaf in full-tree coordinates. Returns: - An opening of sibling hashes from bottom to top. + An opening of sibling hashes ordered from the leaf upward. Raises: ValueError: When the subtree is empty or the position is out of bounds. @@ -378,11 +492,18 @@ def combined_path( bottom_tree: HashSubTree, position: Uint64, ) -> HashTreeOpening: - """Concatenate the bottom-tree and top-tree openings for one leaf. + """ + Stitch a bottom-tree opening and a top-tree opening into one full path. + + # Overview + + A signature authenticates its leaf all the way up to the global root. + No single resident subtree spans that whole distance, so two openings are joined. - A signature must authenticate the leaf against the global root. - The bottom opening proves leaf membership in its bottom tree. - The top opening proves the bottom-tree root sits under the global root. + # Proof flow + + bottom opening : proves the leaf sits under its bottom-tree root. + top opening : proves that bottom-tree root sits under the global root. Args: top_tree: The top tree containing the global root. @@ -390,11 +511,11 @@ def combined_path( position: Absolute index of the leaf. Returns: - An opening with depth siblings authenticating the leaf against the global root. + One opening that authenticates the leaf against the global root. Raises: - ValueError: When tree depths mismatch, depth is odd, or position is out - of bounds for the supplied bottom tree. + ValueError: When the tree depths disagree, the depth is odd, or the position + does not belong to the supplied bottom tree. """ if top_tree.depth != bottom_tree.depth: raise ValueError(f"Depth mismatch: top={top_tree.depth}, bottom={bottom_tree.depth}.") @@ -412,8 +533,7 @@ def combined_path( f"got {bottom_tree.layers[0].start_index}." ) - # Bottom path proves leaf -> bottom-tree root. - # Top path proves bottom root -> global root. + # The opening climbs from leaf to root, so bottom siblings come before top siblings. bottom_path = bottom_tree.path(position) top_path = top_tree.path(position // leaves_per_tree) combined = tuple(bottom_path.siblings.data) + tuple(top_path.siblings.data) @@ -430,13 +550,20 @@ def verify_path( leaf_parts: list[HashDigestVector], opening: HashTreeOpening, ) -> bool: - """Verify a Merkle opening against a trusted root. + """ + Recompute a root from a leaf and its opening, then compare against a trusted root. + + # Overview + + Verification mirrors construction in reverse. + The leaf is hashed, then folded with each sibling while climbing one level per step. + The walk succeeds when the recomputed root equals the trusted root. - Phase 1: hash leaf_parts into the leaf digest. - Phase 2: walk the opening, hashing the current node with each sibling. - Phase 3: compare the reconstructed root with the trusted one. + # Why return false instead of raising - Returns False on attacker-controlled invalid input instead of raising. + The opening arrives inside an untrusted signature. + A malformed opening must be a quiet verification failure, never a crash. + So out-of-range input returns false rather than raising. Args: poseidon: Cached Poseidon1 engine. @@ -448,7 +575,7 @@ def verify_path( opening: Sibling path from leaf to root. Returns: - True when the path reconstructs the root, False otherwise. + True when the path reconstructs the trusted root, false otherwise. """ # Guard against malformed openings. # The opening list caps at 32 entries. diff --git a/src/lean_spec/subspecs/xmss/types.py b/src/lean_spec/subspecs/xmss/types.py index 62efae62a..ed88836d6 100644 --- a/src/lean_spec/subspecs/xmss/types.py +++ b/src/lean_spec/subspecs/xmss/types.py @@ -109,31 +109,3 @@ class HashTreeOpening(Container): siblings: HashDigestList """SSZ-compliant list of sibling hashes, from bottom to top.""" - - -class HashTreeLayer(Container): - """A single horizontal slice of the sparse Merkle tree. - - The tree is sparse: only nodes computed for the active leaf range are stored. - """ - - start_index: Uint64 - """The starting index of the first node in this layer.""" - nodes: HashDigestList - """SSZ-compliant list of hash digests stored for this layer.""" - - -# Why: layers run from 0 (leaves) up to LOG_LIFETIME (root) inclusive. -LAYERS_LIMIT: Final = TARGET_CONFIG.LOG_LIFETIME + 1 -"""Maximum number of layers in a subtree.""" - - -class HashTreeLayers(SSZList[HashTreeLayer]): - """Variable-length list of Merkle tree layers. - - Represents the layers of a subtree, from the lowest layer up to the root. - Bottom and top trees each cover half the depth. - The cap allows the full tree. - """ - - LIMIT = LAYERS_LIMIT diff --git a/tests/consensus/lstar/ssz/test_xmss_containers.py b/tests/consensus/lstar/ssz/test_xmss_containers.py index 4b10ded57..a615be188 100644 --- a/tests/consensus/lstar/ssz/test_xmss_containers.py +++ b/tests/consensus/lstar/ssz/test_xmss_containers.py @@ -10,11 +10,11 @@ TypeOneMultiSignature, TypeTwoMultiSignature, ) +from lean_spec.subspecs.xmss.merkle import HashTreeLayer from lean_spec.subspecs.xmss.types import ( HASH_DIGEST_LENGTH, HashDigestList, HashDigestVector, - HashTreeLayer, HashTreeOpening, Parameter, ) From b9a5ad90a0ca0e51de310790f8f31d1f39aaefb5 Mon Sep 17 00:00:00 2001 From: Thomas Coratger <60488569+tcoratger@users.noreply.github.com> Date: Wed, 27 May 2026 19:26:45 +0200 Subject: [PATCH 7/9] refactor(xmss): colocate PRFKey with the PRF, tighten base types Move the master key type next to the pseudorandom function and expose its derivations as methods: a fresh-key constructor, a chain-start derivation, and a signing-randomness derivation. Update all call sites accordingly. Tighten the base types: type the PRF domain separator as a fixed-length byte string, inline the single-use digest-length alias, and express the layer node cap in terms of the leaves-per-bottom-tree property. Clarify the docs across the PRF and base-type modules per the project doc rules. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lean_spec/subspecs/xmss/containers.py | 2 +- src/lean_spec/subspecs/xmss/interface.py | 8 +- src/lean_spec/subspecs/xmss/merkle.py | 7 +- src/lean_spec/subspecs/xmss/prf.py | 222 ++++++++++-------- src/lean_spec/subspecs/xmss/types.py | 60 ++--- .../lstar/ssz/test_xmss_containers.py | 12 +- tests/lean_spec/subspecs/xmss/test_prf.py | 17 +- tests/lean_spec/subspecs/xmss/test_utils.py | 8 +- 8 files changed, 179 insertions(+), 157 deletions(-) diff --git a/src/lean_spec/subspecs/xmss/containers.py b/src/lean_spec/subspecs/xmss/containers.py index 4db8dac77..06aaef326 100644 --- a/src/lean_spec/subspecs/xmss/containers.py +++ b/src/lean_spec/subspecs/xmss/containers.py @@ -12,12 +12,12 @@ from ...types.exceptions import SSZError from .constants import TARGET_CONFIG from .merkle import HashSubTree +from .prf import PRFKey from .types import ( HashDigestList, HashDigestVector, HashTreeOpening, Parameter, - PRFKey, Randomness, ) diff --git a/src/lean_spec/subspecs/xmss/interface.py b/src/lean_spec/subspecs/xmss/interface.py index 5f8ee3f67..361863456 100644 --- a/src/lean_spec/subspecs/xmss/interface.py +++ b/src/lean_spec/subspecs/xmss/interface.py @@ -9,7 +9,7 @@ from .field import random_parameter from .merkle import HashSubTree, combined_path, verify_path from .poseidon import PROD_POSEIDON, TEST_POSEIDON, PoseidonXmss -from .prf import prf_apply, prf_get_randomness, prf_key_gen +from .prf import PRFKey from .types import HashDigestList, HashDigestVector @@ -103,7 +103,7 @@ def key_gen(self, activation_slot: Slot, num_active_slots: Uint64) -> KeyPair: # Phase 1: draw the master secret and the public parameter. parameter = random_parameter(config) - prf_key = prf_key_gen() + prf_key = PRFKey.generate() # Phase 2: align the requested interval onto bottom-tree boundaries. start_bottom_tree_index, end_bottom_tree_index = _expand_activation_time( @@ -226,7 +226,7 @@ def sign(self, sk: SecretKey, slot: Slot, message: Bytes32) -> Signature: # The randomness is derived from the PRF, keyed by the message and an attempt counter. # Signing the same message twice is therefore reproducible. for attempts in range(config.MAX_TRIES): - rho = prf_get_randomness(config, sk.prf_key, slot, message, Uint64(attempts)) + rho = sk.prf_key.derive_randomness(config, slot, message, Uint64(attempts)) codeword = target_sum_encode(self.poseidon, config, sk.parameter, message, rho, slot) if codeword is not None: break @@ -246,7 +246,7 @@ def sign(self, sk: SecretKey, slot: Slot, message: Bytes32) -> Signature: # The verifier later finishes the remaining steps to reach the chain end. ots_hashes: list[HashDigestVector] = [] for chain_index, steps in enumerate(codeword): - start_digest = prf_apply(config, sk.prf_key, slot, Uint64(chain_index)) + start_digest = sk.prf_key.derive_chain_start(config, slot, Uint64(chain_index)) ots_digest = self.poseidon.hash_chain( config=config, parameter=sk.parameter, diff --git a/src/lean_spec/subspecs/xmss/merkle.py b/src/lean_spec/subspecs/xmss/merkle.py index a0f2cad85..c581eb86b 100644 --- a/src/lean_spec/subspecs/xmss/merkle.py +++ b/src/lean_spec/subspecs/xmss/merkle.py @@ -36,13 +36,12 @@ from .constants import TARGET_CONFIG, XmssConfig from .field import random_domain from .poseidon import PoseidonXmss -from .prf import prf_apply +from .prf import PRFKey from .types import ( HashDigestList, HashDigestVector, HashTreeOpening, Parameter, - PRFKey, TreeTweak, ) @@ -397,7 +396,9 @@ def from_prf_key( # The far end is the chain start hashed forward the full length minus one. chain_ends: list[HashDigestVector] = [] for chain_index in range(config.DIMENSION): - start_digest = prf_apply(config, prf_key, Uint64(epoch), Uint64(chain_index)) + start_digest = prf_key.derive_chain_start( + config, Uint64(epoch), Uint64(chain_index) + ) end_digest = poseidon.hash_chain( config=config, parameter=parameter, diff --git a/src/lean_spec/subspecs/xmss/prf.py b/src/lean_spec/subspecs/xmss/prf.py index fc086ebdf..426b2a40a 100644 --- a/src/lean_spec/subspecs/xmss/prf.py +++ b/src/lean_spec/subspecs/xmss/prf.py @@ -1,23 +1,23 @@ -"""SHAKE128-based pseudorandom function for deterministic key derivation. - -Derives hash-chain starts and signing randomness from one master key. -Every call is domain-separated so the same key never collides across contexts. -""" +"""SHAKE128-based pseudorandom function for deterministic key derivation.""" import hashlib import os from itertools import batched -from typing import Final +from typing import Final, Self -from lean_spec.types import Bytes32, Uint64 +from lean_spec.types import Bytes16, Bytes32, Uint64 +from lean_spec.types.byte_arrays import BaseBytes from ..koalabear import Fp from .constants import PRF_KEY_LENGTH, XmssConfig -from .types import HashDigestVector, PRFKey, Randomness +from .types import HashDigestVector, Randomness + +PRF_DOMAIN_SEP: Final = Bytes16(b"\xae\xae\x22\xff\x00\x01\xfa\xff\x21\xaf\x12\x00\x01\x11\xff\x00") +""" +Fixed domain separator prefixed to every PRF call. -PRF_DOMAIN_SEP: Final[bytes] = b"\xae\xae\x22\xff\x00\x01\xfa\xff\x21\xaf\x12\x00\x01\x11\xff\x00" -"""Fixed 16-byte domain separator used by every PRF call. -Prevents cross-context collisions if SHAKE128 is reused elsewhere in the system.""" +Prevents cross-context collisions if SHAKE128 is reused elsewhere in the system. +""" PRF_DOMAIN_SEP_DOMAIN_ELEMENT: Final[bytes] = b"\x00" """Subdomain tag for hash-chain start derivation.""" @@ -26,95 +26,123 @@ """Subdomain tag for signing-randomness derivation.""" PRF_BYTES_PER_FE: Final[int] = 16 -"""SHAKE128 bytes consumed per output field element. -128 bits reduced modulo a 31-bit prime gives a statistical margin against bias.""" - +""" +SHAKE128 bytes consumed per output field element. -def prf_key_gen() -> PRFKey: - """Generate a fresh master PRF key from the operating system entropy pool.""" - return PRFKey(os.urandom(PRF_KEY_LENGTH)) +128 bits reduced modulo a 31-bit prime gives a statistical margin against bias. +""" -def prf_apply( - config: XmssConfig, key: PRFKey, epoch: Uint64, chain_index: Uint64 -) -> HashDigestVector: - """Derive the secret start of one Winternitz hash chain. +class PRFKey(BaseBytes): + """ + The PRF master secret key. - Args: - config: Active XMSS configuration. - key: Master PRF key. - epoch: Slot identifier for this one-time signature instance. - chain_index: Position of the chain within the one-time signature. + High-entropy byte string acting as the single root secret. - Returns: - A hash digest used as the chain start. + Every one-time signing key is deterministically derived from this seed. """ - # Layout: - # - # domain_sep || 0x00 || key || epoch (4 bytes) || chain_index (8 bytes) - # - # The 0x00 byte separates chain-start derivation from randomness derivation. - input_data = ( - PRF_DOMAIN_SEP - + PRF_DOMAIN_SEP_DOMAIN_ELEMENT - + key - + epoch.to_bytes(4, "big") - + chain_index.to_bytes(8, "big") - ) - - # Pull enough SHAKE128 bytes to fill HASH_LEN_FE field elements. - num_bytes_to_read = PRF_BYTES_PER_FE * config.HASH_LEN_FE - prf_output_bytes = hashlib.shake_128(input_data).digest(num_bytes_to_read) - return HashDigestVector( - data=[ - Fp(value=int.from_bytes(bytes(chunk), "big")) - for chunk in batched(prf_output_bytes, PRF_BYTES_PER_FE) - ] - ) - - -def prf_get_randomness( - config: XmssConfig, - key: PRFKey, - epoch: Uint64, - message: Bytes32, - counter: Uint64, -) -> Randomness: - """Derive deterministic randomness for a signing attempt. - - Same construction as the chain-start derivation, with a different subdomain - tag. - Including the message and counter makes signing reproducible without - breaking security: signing twice with the same key, epoch, and message - always produces the same randomness. - - Args: - config: Active XMSS configuration. - key: Master PRF key. - epoch: Slot identifier for this signature. - message: Full message being signed. - counter: Attempt number, incremented when a previous attempt aborted. - - Returns: - Randomness used to encode the message into a valid codeword. - """ - # Layout: - # - # domain_sep || 0x01 || key || epoch || message || counter - input_data = ( - PRF_DOMAIN_SEP - + PRF_DOMAIN_SEP_RANDOMNESS - + key - + epoch.to_bytes(4, "big") - + message - + counter.to_bytes(8, "big") - ) - - num_bytes_to_read = PRF_BYTES_PER_FE * config.RAND_LEN_FE - prf_output_bytes = hashlib.shake_128(input_data).digest(num_bytes_to_read) - return Randomness( - data=[ - Fp(value=int.from_bytes(bytes(chunk), "big")) - for chunk in batched(prf_output_bytes, PRF_BYTES_PER_FE) - ] - ) + + LENGTH = PRF_KEY_LENGTH + + @classmethod + def generate(cls) -> Self: + """Draw a fresh master key from the operating system entropy pool.""" + return cls(os.urandom(PRF_KEY_LENGTH)) + + def derive_chain_start( + self, config: XmssConfig, epoch: Uint64, chain_index: Uint64 + ) -> HashDigestVector: + """ + Derive the secret start of one Winternitz hash chain. + + # Overview + + Each slot signs with many independent hash chains. + A chain begins at a secret value. + Its public counterpart is that value hashed all the way up to the chain top. + Recreating the secret start from the seed means it never has to be stored. + + Args: + config: Active XMSS configuration. + epoch: Slot identifier for this one-time signature instance. + chain_index: Position of the chain within the one-time signature. + + Returns: + The secret digest at the bottom of the chain. + """ + # Layout: + # + # domain_sep || 0x00 || key || epoch (4 bytes) || chain_index (8 bytes) + # + # The 0x00 byte separates chain-start derivation from randomness derivation. + input_data = ( + PRF_DOMAIN_SEP + + PRF_DOMAIN_SEP_DOMAIN_ELEMENT + + self + + epoch.to_bytes(4, "big") + + chain_index.to_bytes(8, "big") + ) + + # Pull enough SHAKE128 bytes to fill one digest of field elements. + num_bytes_to_read = PRF_BYTES_PER_FE * config.HASH_LEN_FE + prf_output_bytes = hashlib.shake_128(input_data).digest(num_bytes_to_read) + return HashDigestVector( + data=[ + Fp(value=int.from_bytes(bytes(chunk), "big")) + for chunk in batched(prf_output_bytes, PRF_BYTES_PER_FE) + ] + ) + + def derive_randomness( + self, + config: XmssConfig, + epoch: Uint64, + message: Bytes32, + counter: Uint64, + ) -> Randomness: + """ + Derive deterministic randomness for one signing attempt. + + # Overview + + Signing maps the message onto a codeword whose digits must add up to a fixed target. + - Each attempt folds in fresh randomness. + - It retries with a higher counter until the digit sum lands on the target. + + Deriving that randomness from the seed makes the search reproducible. + + # Reproducibility + + A synchronized scheme signs each slot at most once. + Signing one slot twice is treated as misbehavior. + A deterministic attempt order means one slot and message always yield the same signature. + + Args: + config: Active XMSS configuration. + epoch: Slot identifier for this signature. + message: Full message being signed. + counter: Attempt number, incremented when a previous attempt aborted. + + Returns: + Randomness used to encode the message into a valid codeword. + """ + # Layout: + # + # domain_sep || 0x01 || key || epoch || message || counter + input_data = ( + PRF_DOMAIN_SEP + + PRF_DOMAIN_SEP_RANDOMNESS + + self + + epoch.to_bytes(4, "big") + + message + + counter.to_bytes(8, "big") + ) + + num_bytes_to_read = PRF_BYTES_PER_FE * config.RAND_LEN_FE + prf_output_bytes = hashlib.shake_128(input_data).digest(num_bytes_to_read) + return Randomness( + data=[ + Fp(value=int.from_bytes(bytes(chunk), "big")) + for chunk in batched(prf_output_bytes, PRF_BYTES_PER_FE) + ] + ) diff --git a/src/lean_spec/subspecs/xmss/types.py b/src/lean_spec/subspecs/xmss/types.py index ed88836d6..b0c4567fd 100644 --- a/src/lean_spec/subspecs/xmss/types.py +++ b/src/lean_spec/subspecs/xmss/types.py @@ -5,10 +5,9 @@ from lean_spec.subspecs.koalabear import Fp from ...types import Uint64 -from ...types.byte_arrays import BaseBytes from ...types.collections import SSZList, SSZVector from ...types.container import Container -from .constants import PRF_KEY_LENGTH, TARGET_CONFIG +from .constants import TARGET_CONFIG class TreeTweak(NamedTuple): @@ -41,35 +40,25 @@ class ChainTweak(NamedTuple): """ -class PRFKey(BaseBytes): - """The PRF master secret key. +NODE_LIST_LIMIT: Final = 2 * TARGET_CONFIG.LEAVES_PER_BOTTOM_TREE +""" +Maximum number of nodes a sparse Merkle tree layer can hold. - High-entropy byte string acting as the single root secret. - Every one-time signing key is deterministically derived from this seed. - """ - - LENGTH = PRF_KEY_LENGTH - - -HASH_DIGEST_LENGTH: Final = TARGET_CONFIG.HASH_LEN_FE -"""Length of one hash digest in field elements. - -Corresponds to the Poseidon1 output length used in the XMSS scheme.""" - -# Why: a bottom tree spans 2^(LOG_LIFETIME/2) leaves. -# Padding may add up to two extra siblings. -# Doubling that bound leaves room for future-proof layouts without resizing. -NODE_LIST_LIMIT: Final = 1 << (TARGET_CONFIG.LOG_LIFETIME // 2 + 1) -"""Maximum number of nodes that can be stored in a sparse Merkle tree layer.""" +- The widest layer is a bottom tree's leaf row, the square root of the lifetime in leaves. +- Padding adds at most one sibling at each end. +- Twice the leaf count is a generous cap that absorbs the padding with room to spare. +""" class HashDigestVector(SSZVector[Fp]): - """A single hash digest as a fixed-size vector of field elements. + """ + A single hash digest as a fixed-size vector of field elements. The fixed size lets SSZ pack these back-to-back without per-element offsets. """ - LENGTH = HASH_DIGEST_LENGTH + LENGTH = TARGET_CONFIG.HASH_LEN_FE + """One Poseidon1 digest, measured in field elements.""" class HashDigestList(SSZList[HashDigestVector]): @@ -81,31 +70,34 @@ class HashDigestList(SSZList[HashDigestVector]): class Parameter(SSZVector[Fp]): """The public parameter P. - Unique, randomly generated value associated with a single key pair. - Mixed into every hash to personalize the function and block cross-key attacks. - Public knowledge. + - Unique, randomly generated value associated with a single key pair. + - Mixed into every hash to personalize the function and block cross-key attacks. + - Public knowledge. """ LENGTH = TARGET_CONFIG.PARAMETER_LEN class Randomness(SSZVector[Fp]): - """The randomness rho used during signing. + """ + Fresh randomness mixed into the message hash during signing. - Variable input to the message hash so the signer can resample until a - valid codeword is found. - Included in the final signature so the verifier reproduces the hash. + - Signing rehashes the message with new randomness on each attempt. + - Retries continue until the resulting codeword hits the target sum. + - The chosen randomness travels in the signature so the verifier recomputes the same codeword. """ LENGTH = TARGET_CONFIG.RAND_LEN_FE class HashTreeOpening(Container): - """A Merkle authentication path. + """ + A Merkle authentication path proving one leaf sits under the root. - Contains the minimal proof connecting a specific leaf to the Merkle root. - Holds every sibling node along the path from the leaf to the tree top. + - The path lists the sibling hashes met while climbing from the leaf up to the root. + - A verifier rehashes the leaf upward with these siblings. + - The reconstructed root must equal the trusted root. """ siblings: HashDigestList - """SSZ-compliant list of sibling hashes, from bottom to top.""" + """Sibling hashes, ordered from the leaf upward to the root.""" diff --git a/tests/consensus/lstar/ssz/test_xmss_containers.py b/tests/consensus/lstar/ssz/test_xmss_containers.py index a615be188..46e2b4f00 100644 --- a/tests/consensus/lstar/ssz/test_xmss_containers.py +++ b/tests/consensus/lstar/ssz/test_xmss_containers.py @@ -10,9 +10,9 @@ TypeOneMultiSignature, TypeTwoMultiSignature, ) +from lean_spec.subspecs.xmss.constants import TARGET_CONFIG from lean_spec.subspecs.xmss.merkle import HashTreeLayer from lean_spec.subspecs.xmss.types import ( - HASH_DIGEST_LENGTH, HashDigestList, HashDigestVector, HashTreeOpening, @@ -36,7 +36,7 @@ def _zero_hash_digest_vector() -> HashDigestVector: """Build a hash digest vector with all field elements set to zero.""" - return HashDigestVector(data=[Fp(0) for _ in range(HASH_DIGEST_LENGTH)]) + return HashDigestVector(data=[Fp(0) for _ in range(TARGET_CONFIG.HASH_LEN_FE)]) def _zero_parameter() -> Parameter: @@ -120,7 +120,7 @@ def test_public_key_typical(ssz: SSZTestFiller) -> None: ssz( type_name="PublicKey", value=PublicKey( - root=HashDigestVector(data=[Fp(i + 1) for i in range(HASH_DIGEST_LENGTH)]), + root=HashDigestVector(data=[Fp(i + 1) for i in range(TARGET_CONFIG.HASH_LEN_FE)]), parameter=Parameter(data=[Fp(100 + i) for i in range(Parameter.LENGTH)]), ), ) @@ -144,7 +144,9 @@ def test_hash_tree_opening_typical(ssz: SSZTestFiller) -> None: value=HashTreeOpening( siblings=HashDigestList( data=[ - HashDigestVector(data=[Fp(i + j * 10) for i in range(HASH_DIGEST_LENGTH)]) + HashDigestVector( + data=[Fp(i + j * 10) for i in range(TARGET_CONFIG.HASH_LEN_FE)] + ) for j in range(3) ] ) @@ -174,7 +176,7 @@ def test_hash_tree_layer_typical(ssz: SSZTestFiller) -> None: start_index=Uint64(42), nodes=HashDigestList( data=[ - HashDigestVector(data=[Fp(i + j * 7) for i in range(HASH_DIGEST_LENGTH)]) + HashDigestVector(data=[Fp(i + j * 7) for i in range(TARGET_CONFIG.HASH_LEN_FE)]) for j in range(2) ] ), diff --git a/tests/lean_spec/subspecs/xmss/test_prf.py b/tests/lean_spec/subspecs/xmss/test_prf.py index efa9fa3bc..70b5b9d17 100644 --- a/tests/lean_spec/subspecs/xmss/test_prf.py +++ b/tests/lean_spec/subspecs/xmss/test_prf.py @@ -4,8 +4,7 @@ PRF_KEY_LENGTH, TEST_CONFIG, ) -from lean_spec.subspecs.xmss.prf import prf_apply, prf_key_gen -from lean_spec.subspecs.xmss.types import PRFKey +from lean_spec.subspecs.xmss.prf import PRFKey from lean_spec.types import Uint64 @@ -17,14 +16,14 @@ def test_key_gen_is_random() -> None: This test mirrors the logic from the reference Rust implementation. """ # Check that the key has the correct length. - key = prf_key_gen() + key = PRFKey.generate() assert len(key) == PRF_KEY_LENGTH # Generate multiple keys and ensure they are not all identical. # # This is a basic check to ensure we are getting fresh randomness. num_trials = 10 - keys = {prf_key_gen() for _ in range(num_trials)} + keys = {PRFKey.generate() for _ in range(num_trials)} assert len(keys) == num_trials # Check that the keys are not filled with a single repeated byte. @@ -33,7 +32,7 @@ def test_key_gen_is_random() -> None: # such a key, so this is a good health check. all_same_count = 0 for _ in range(num_trials): - key = prf_key_gen() + key = PRFKey.generate() # A set will have size 1 if all elements are the same. if len(set(key)) == 1: all_same_count += 1 @@ -53,20 +52,20 @@ def test_apply_is_sensitive_to_inputs() -> None: key1 = PRFKey(b"\x11" * PRF_KEY_LENGTH) epoch1 = Uint64(10) chain_index1 = Uint64(20) - baseline_output = prf_apply(config, key1, epoch1, chain_index1) + baseline_output = key1.derive_chain_start(config, epoch1, chain_index1) assert len(baseline_output) == config.HASH_LEN_FE # Test sensitivity to the key. key2 = PRFKey(b"\x22" * PRF_KEY_LENGTH) - output_key_changed = prf_apply(config, key2, epoch1, chain_index1) + output_key_changed = key2.derive_chain_start(config, epoch1, chain_index1) assert baseline_output != output_key_changed # Test sensitivity to the epoch. epoch2 = Uint64(11) - output_epoch_changed = prf_apply(config, key1, epoch2, chain_index1) + output_epoch_changed = key1.derive_chain_start(config, epoch2, chain_index1) assert baseline_output != output_epoch_changed # Test sensitivity to the chain_index. chain_index2 = Uint64(21) - output_index_changed = prf_apply(config, key1, epoch1, chain_index2) + output_index_changed = key1.derive_chain_start(config, epoch1, chain_index2) assert baseline_output != output_index_changed diff --git a/tests/lean_spec/subspecs/xmss/test_utils.py b/tests/lean_spec/subspecs/xmss/test_utils.py index dc790bf08..3be470bf1 100644 --- a/tests/lean_spec/subspecs/xmss/test_utils.py +++ b/tests/lean_spec/subspecs/xmss/test_utils.py @@ -11,7 +11,7 @@ from lean_spec.subspecs.xmss.interface import _expand_activation_time from lean_spec.subspecs.xmss.merkle import HashSubTree from lean_spec.subspecs.xmss.poseidon import TEST_POSEIDON -from lean_spec.subspecs.xmss.prf import prf_key_gen +from lean_spec.subspecs.xmss.prf import PRFKey from lean_spec.subspecs.xmss.types import Parameter from lean_spec.types import Uint64 @@ -111,7 +111,7 @@ def test_hash_subtree_from_prf_key() -> None: config = TEST_CONFIG # Generate a PRF key - prf_key = prf_key_gen() + prf_key = PRFKey.generate() # Generate a random parameter parameter = Parameter( @@ -145,7 +145,7 @@ def test_hash_subtree_from_prf_key() -> None: def test_hash_subtree_from_prf_key_deterministic() -> None: """Tests that HashSubTree.from_prf_key is deterministic.""" config = TEST_CONFIG - prf_key = prf_key_gen() + prf_key = PRFKey.generate() parameter = Parameter( data=[Fp(value=secrets.randbelow(P)) for _ in range(config.PARAMETER_LEN)] ) @@ -174,7 +174,7 @@ def test_hash_subtree_from_prf_key_deterministic() -> None: def test_hash_subtree_from_prf_key_different_indices() -> None: """Tests that different bottom tree indices produce different trees.""" config = TEST_CONFIG - prf_key = prf_key_gen() + prf_key = PRFKey.generate() parameter = Parameter( data=[Fp(value=secrets.randbelow(P)) for _ in range(config.PARAMETER_LEN)] ) From b792c4f552fa8b79c8ed63406741b540146a4712 Mon Sep 17 00:00:00 2001 From: Thomas Coratger <60488569+tcoratger@users.noreply.github.com> Date: Wed, 27 May 2026 21:31:50 +0200 Subject: [PATCH 8/9] test(xmss): mirror source modules and complete coverage Reorganize the XMSS unit tests so each source module has one matching test file, redistributing the orphan files into the module they exercise. Drive every module to full line and branch coverage apart from three intentionally unreachable arcs, deduplicate overlapping cases, split bundled assertions into single-behavior tests, parametrize scenario families across sizes and edges, and match the complete text of every error message the spec itself raises. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../subspecs/xmss/test_aggregation.py | 69 ++++ .../lean_spec/subspecs/xmss/test_constants.py | 285 +++++++++++++ .../subspecs/xmss/test_containers.py | 113 ++++- .../lean_spec/subspecs/xmss/test_encoding.py | 147 +++++++ tests/lean_spec/subspecs/xmss/test_field.py | 88 ++++ .../lean_spec/subspecs/xmss/test_interface.py | 79 +++- tests/lean_spec/subspecs/xmss/test_merkle.py | 387 ++++++++++++++++++ .../subspecs/xmss/test_merkle_tree.py | 197 --------- .../subspecs/xmss/test_message_hash.py | 122 ------ .../lean_spec/subspecs/xmss/test_poseidon.py | 184 +++++++++ .../subspecs/xmss/test_security_levels.py | 307 -------------- .../subspecs/xmss/test_ssz_serialization.py | 147 ------- tests/lean_spec/subspecs/xmss/test_types.py | 92 +++++ tests/lean_spec/subspecs/xmss/test_utils.py | 200 --------- 14 files changed, 1438 insertions(+), 979 deletions(-) create mode 100644 tests/lean_spec/subspecs/xmss/test_constants.py create mode 100644 tests/lean_spec/subspecs/xmss/test_encoding.py create mode 100644 tests/lean_spec/subspecs/xmss/test_field.py create mode 100644 tests/lean_spec/subspecs/xmss/test_merkle.py delete mode 100644 tests/lean_spec/subspecs/xmss/test_merkle_tree.py delete mode 100644 tests/lean_spec/subspecs/xmss/test_message_hash.py create mode 100644 tests/lean_spec/subspecs/xmss/test_poseidon.py delete mode 100644 tests/lean_spec/subspecs/xmss/test_security_levels.py delete mode 100644 tests/lean_spec/subspecs/xmss/test_ssz_serialization.py create mode 100644 tests/lean_spec/subspecs/xmss/test_types.py delete mode 100644 tests/lean_spec/subspecs/xmss/test_utils.py diff --git a/tests/lean_spec/subspecs/xmss/test_aggregation.py b/tests/lean_spec/subspecs/xmss/test_aggregation.py index 1e345d0c2..b23c2712e 100644 --- a/tests/lean_spec/subspecs/xmss/test_aggregation.py +++ b/tests/lean_spec/subspecs/xmss/test_aggregation.py @@ -242,6 +242,57 @@ def test_aggregate_mixed_children_and_raw_multiple(key_manager: XmssKeyManager) ) +def test_type_one_verify_rejects_pubkey_count_mismatch(key_manager: XmssKeyManager) -> None: + """Type-1 verification refuses a pubkey set that does not match the bitfield.""" + source = Checkpoint(root=make_bytes32(160), slot=Slot(0)) + att_args = (Slot(2), 161, 162, source) + vids = [ValidatorIndex(0), ValidatorIndex(1)] + + proof = _sign_and_aggregate(key_manager, vids, att_args) + # The bitfield names two validators but only one key is supplied. + only_one = [key_manager[ValidatorIndex(0)].attestation_keypair.public_key] + + with pytest.raises( + AggregationError, match="Type-1 verify expected 2 pubkeys for participants, got 1" + ): + proof.verify(public_keys=only_one, message=make_bytes32(161), slot=att_args[0]) + + +def test_type_two_split_by_msg_rejected_under_test_prover(key_manager: XmssKeyManager) -> None: + """Splitting a merged proof aborts under the reduced test-config prover. + + The split branch is functional only under the production prover. + The test-config build aborts it with an in-circuit assertion. + Exercising it here drives the serialization and error-translation path. + """ + source = Checkpoint(root=make_bytes32(600), slot=Slot(0)) + att_args_a = (Slot(11), 601, 602, source) + att_args_b = (Slot(11), 603, 604, source) + att_data_a = make_attestation_data_simple( + att_args_a[0], make_bytes32(att_args_a[1]), make_bytes32(att_args_a[2]), att_args_a[3] + ) + + vids_a = [ValidatorIndex(0), ValidatorIndex(1)] + vids_b = [ValidatorIndex(2), ValidatorIndex(3)] + part_a = _sign_and_aggregate(key_manager, vids_a, att_args_a) + part_b = _sign_and_aggregate(key_manager, vids_b, att_args_b) + + pubkeys_a = [key_manager[vid].attestation_keypair.public_key for vid in vids_a] + pubkeys_b = [key_manager[vid].attestation_keypair.public_key for vid in vids_b] + + merged = TypeTwoMultiSignature.aggregate( + parts=[part_a, part_b], + public_keys_per_part=[pubkeys_a, pubkeys_b], + ) + + with pytest.raises(AggregationError, match="Type-2 split failed"): + merged.split_by_msg( + message=hash_tree_root(att_data_a), + public_keys_per_message=[pubkeys_a, pubkeys_b], + participants=part_a.participants, + ) + + def test_aggregate_wrong_message_fails_verification(key_manager: XmssKeyManager) -> None: """Verification fails when the caller passes a message that does not match the proof.""" source = Checkpoint(root=make_bytes32(120), slot=Slot(0)) @@ -349,6 +400,24 @@ def test_type_two_aggregate_rejects_mismatched_pubkey_layout( ) +def test_type_two_aggregate_propagates_prover_error(key_manager: XmssKeyManager) -> None: + """A corrupted component proof makes the merge prover reject the inputs.""" + source = Checkpoint(root=make_bytes32(210), slot=Slot(0)) + att_args = (Slot(8), 211, 212, source) + vids = [ValidatorIndex(0), ValidatorIndex(1)] + + part = _sign_and_aggregate(key_manager, vids, att_args) + pubkeys = [key_manager[vid].attestation_keypair.public_key for vid in vids] + + corrupted_bytes = bytearray(part.proof.data) + corrupted_bytes[10] ^= 0xFF + corrupted_bytes[20] ^= 0xFF + corrupted = part.model_copy(update={"proof": ByteList512KiB(data=bytes(corrupted_bytes))}) + + with pytest.raises(AggregationError, match="merge_many_type_1 failed"): + TypeTwoMultiSignature.aggregate(parts=[corrupted], public_keys_per_part=[pubkeys]) + + def test_type_two_verify_round_trip(key_manager: XmssKeyManager) -> None: """A Type-2 merge of two distinct-message Type-1 proofs round-trips through verify.""" source = Checkpoint(root=make_bytes32(300), slot=Slot(0)) diff --git a/tests/lean_spec/subspecs/xmss/test_constants.py b/tests/lean_spec/subspecs/xmss/test_constants.py new file mode 100644 index 000000000..3b504c2d6 --- /dev/null +++ b/tests/lean_spec/subspecs/xmss/test_constants.py @@ -0,0 +1,285 @@ +""" +Tests for the XMSS cryptographic constants, configuration presets, and security margins. + +The security-margin tests validate that the production parameter choices achieve +adequate classical and quantum security. + +Based on: + +- [DKKW25c] "Hash-Based Multi-Signatures for Post-Quantum Ethereum" + (https://eprint.iacr.org/2025/055.pdf) +- [HKKTW26] "Aborting Random Oracles" + (https://eprint.iacr.org/2026/016) + +The security analysis follows the framework of [DKKW25c] Section 6. +""" + +import math + +import pytest + +from lean_spec.subspecs.koalabear import P_BYTES, P +from lean_spec.subspecs.xmss.constants import ( + PROD_CONFIG, + TARGET_CONFIG, + TEST_CONFIG, + XmssConfig, +) +from lean_spec.types import Uint64 +from lean_spec.types.ssz_base import BYTES_PER_LENGTH_OFFSET + + +def _valid_config_kwargs() -> dict[str, int]: + """Return a copy of the production configuration fields as plain kwargs.""" + return PROD_CONFIG.model_dump() + + +def test_decomposition_validator_rejects_bad_product() -> None: + """A configuration whose product does not equal the prime minus one is rejected.""" + kwargs = _valid_config_kwargs() + kwargs["Q"] = 128 + with pytest.raises(ValueError, match=f"Q \\* BASE\\^Z must equal P-1={P - 1}"): + XmssConfig(**kwargs) + + +def test_decomposition_validator_accepts_valid_product() -> None: + """A configuration whose product equals the prime minus one validates.""" + config = XmssConfig(**_valid_config_kwargs()) + assert config.Q * config.BASE**config.Z == P - 1 + + +def test_target_config_is_test_config_under_test_env() -> None: + """The active configuration under the test environment is the test preset.""" + assert TARGET_CONFIG is TEST_CONFIG + + +def test_lifetime_is_two_to_the_log_lifetime() -> None: + """The lifetime is two raised to the configured base-two logarithm.""" + assert TEST_CONFIG.LIFETIME == Uint64(1 << TEST_CONFIG.LOG_LIFETIME) + + +def test_leaves_per_bottom_tree_is_square_root_of_lifetime() -> None: + """One bottom tree covers the square root of the lifetime in leaves.""" + assert TEST_CONFIG.LEAVES_PER_BOTTOM_TREE == 1 << (TEST_CONFIG.LOG_LIFETIME // 2) + + +@pytest.mark.parametrize( + "dimension, z, expected", + [ + pytest.param(4, 8, 1, id="dimension below one field element rounds up to one"), + pytest.param(46, 8, 6, id="production dimension needs six field elements"), + pytest.param(16, 8, 2, id="two full field elements exactly cover sixteen digits"), + ], +) +def test_mh_hash_len_rounds_up(dimension: int, z: int, expected: int) -> None: + """The aborting-decode output length is the dimension divided by digits, rounded up.""" + kwargs = _valid_config_kwargs() + kwargs["DIMENSION"] = dimension + kwargs["Z"] = z + assert XmssConfig(**kwargs).MH_HASH_LEN_FE == expected + + +def test_signature_len_bytes_matches_layout() -> None: + """The advertised signature length equals the sum of its SSZ-encoded fields.""" + config = TEST_CONFIG + path = config.LOG_LIFETIME * config.HASH_LEN_FE * P_BYTES + rho = config.RAND_LEN_FE * P_BYTES + hashes = config.DIMENSION * config.HASH_LEN_FE * P_BYTES + expected = path + rho + hashes + 3 * BYTES_PER_LENGTH_OFFSET + assert config.SIGNATURE_LEN_BYTES == expected + + +@pytest.mark.parametrize( + "param_name, value", + [ + ("DIMENSION", 46), + ("BASE", 8), + ("Z", 8), + ("Q", 127), + ("TARGET_SUM", 200), + ("LOG_LIFETIME", 32), + ("PARAMETER_LEN", 5), + ("TWEAK_LEN_FE", 2), + ("MSG_LEN_FE", 9), + ("RAND_LEN_FE", 7), + ("HASH_LEN_FE", 8), + ("CAPACITY", 9), + ], +) +def test_prod_config_matches_reference(param_name: str, value: int) -> None: + """Production parameters must match the canonical Rust implementation.""" + assert getattr(PROD_CONFIG, param_name) == value + + +def _calculate_layer_size(w: int, v: int, d: int) -> int: + """Count a hypercube layer's size using inclusion-exclusion. + + Counts integer solutions to x_1 + ... + x_v = k with 0 <= x_i <= w-1, + where k = v*(w-1) - d. + """ + coord_sum = v * (w - 1) - d + return sum( + ((-1) ** s) * math.comb(v, s) * math.comb(coord_sum - s * w + v - 1, v - 1) + for s in range(coord_sum // w + 1) + ) + + +def _compute_security_levels(config: XmssConfig) -> dict[str, float]: + """Compute classical and quantum security levels for a configuration. + + Returns a dict with keys: + + - k_classical: effective classical security in bits + - k_quantum: effective quantum security in bits + - expected_attempts: expected signing attempts per message + - signing_failure_log2: log2 of the probability that all attempts fail + """ + v = config.DIMENSION + w_bits = int(math.log2(config.BASE)) + base = config.BASE + + # Each field element contributes floor(log2(P)) = 31 bits. + fe_bits = 31 + bits_digest = config.HASH_LEN_FE * fe_bits + bits_param = config.PARAMETER_LEN * fe_bits + bits_rand = config.RAND_LEN_FE * fe_bits + + # Raw message hash output is v chunks of w bits each. + bits_msg = v * w_bits + + # Abort correction from [HKKTW26] Corollary 1, Remark 14. + # + # Each field element aborts iff it equals the prime minus one. + # The non-abort probability per element is (P - 1) / P. + # Over ell field elements the total non-abort probability is that ratio to the ell. + wz = base**config.Z + q = config.Q + ell = math.ceil(v / config.Z) + + non_abort_total = ((q * wz) / P) ** ell + abort_correction_bits = -math.log2(non_abort_total) + + bits_msg_eff = bits_msg + abort_correction_bits + + log5 = math.log2(5) + log12 = math.log2(12) + log3 = math.log2(3) + log_lifetime = math.log2(config.LIFETIME) + logv = math.log2(v) + log_max_tries = math.log2(config.MAX_TRIES) + logqs = math.log2(config.LIFETIME) + + # Classical security is the minimum over four attack surfaces. + k_classical = min( + bits_digest - log5 - 2 * w_bits - log_lifetime - logv, + bits_param - log5 - 3, + bits_msg_eff - log5 - 1, + bits_rand - log5 - logqs - log_max_tries - 1, + ) + + # Quantum security is the minimum over four attack surfaces. + k_quantum = min( + bits_digest / 2 - log5 - 2 * w_bits - log_lifetime - logv - log12, + (bits_param - 5) / 2 - log5 - 2, + (bits_msg_eff - 3) / 2 - log5 - 1, + (bits_rand - logqs) / 2 - log5 - log3 - log_max_tries, + ) + + # Expected signing attempts for target-sum encoding. + d = v * (base - 1) - config.TARGET_SUM + layer_size = _calculate_layer_size(base, v, d) + layer_prob = layer_size / base**v + success_prob = non_abort_total * layer_prob + expected_attempts = 1 / success_prob + + signing_failure_log2 = config.MAX_TRIES * math.log2(1 - success_prob) + + return { + "k_classical": k_classical, + "k_quantum": k_quantum, + "expected_attempts": expected_attempts, + "signing_failure_log2": signing_failure_log2, + } + + +def test_prod_classical_security() -> None: + """Production parameters achieve at least 128-bit classical security.""" + levels = _compute_security_levels(PROD_CONFIG) + assert levels["k_classical"] >= 128 + + +def test_prod_quantum_security() -> None: + """Production parameters achieve at least 64-bit quantum security.""" + levels = _compute_security_levels(PROD_CONFIG) + assert levels["k_quantum"] >= 64 + + +def test_prod_expected_signing_attempts_are_bounded() -> None: + """Signing succeeds within a manageable number of attempts on average.""" + levels = _compute_security_levels(PROD_CONFIG) + assert levels["expected_attempts"] < 1000 + + +def test_prod_signing_failure_is_negligible() -> None: + """The probability of exhausting every attempt is below two to the minus 128.""" + levels = _compute_security_levels(PROD_CONFIG) + assert levels["signing_failure_log2"] < -128 + + +def test_prod_abort_probability_is_negligible() -> None: + """The aborting decode rejection probability is below two to the minus 28.""" + config = PROD_CONFIG + ell = math.ceil(config.DIMENSION / config.Z) + non_abort_per_fe = (config.Q * config.BASE**config.Z) / P + total_non_abort = non_abort_per_fe**ell + assert 1 - total_non_abort < 2**-28 + + +def test_prod_base_is_power_of_two() -> None: + """The alphabet size is a power of two so digits map cleanly onto bits.""" + w_bits = int(math.log2(PROD_CONFIG.BASE)) + assert PROD_CONFIG.BASE == 2**w_bits + + +def test_prod_digit_width_divides_twenty_four() -> None: + """The digit width divides twenty-four so rejection sampling works for KoalaBear.""" + w_bits = int(math.log2(PROD_CONFIG.BASE)) + assert 24 % w_bits == 0 + + +def test_prod_z_equals_twenty_four_over_digit_width() -> None: + """The digit count equals twenty-four divided by the digit width for the optimal decode.""" + w_bits = int(math.log2(PROD_CONFIG.BASE)) + assert PROD_CONFIG.Z == 24 // w_bits + + +def test_prod_mh_hash_len_covers_dimension() -> None: + """The aborting-decode output produces at least one digit per hash chain.""" + config = PROD_CONFIG + assert config.MH_HASH_LEN_FE * config.Z >= config.DIMENSION + + +def test_prod_binding_constraint_is_message_hash() -> None: + """The tightest classical bound is the message hash, matching the design intent.""" + config = PROD_CONFIG + v = config.DIMENSION + w_bits = int(math.log2(config.BASE)) + fe_bits = 31 + + bits_digest = config.HASH_LEN_FE * fe_bits + bits_param = config.PARAMETER_LEN * fe_bits + bits_rand = config.RAND_LEN_FE * fe_bits + bits_msg = v * w_bits + + log5 = math.log2(5) + log_lifetime = math.log2(config.LIFETIME) + logv = math.log2(v) + log_max_tries = math.log2(config.MAX_TRIES) + + classical_bounds = [ + bits_digest - log5 - 2 * w_bits - log_lifetime - logv, + bits_param - log5 - 3, + bits_msg - log5 - 1, + bits_rand - log5 - log_lifetime - log_max_tries - 1, + ] + assert classical_bounds.index(min(classical_bounds)) == 2 diff --git a/tests/lean_spec/subspecs/xmss/test_containers.py b/tests/lean_spec/subspecs/xmss/test_containers.py index c4c981dfd..3934b2a1c 100644 --- a/tests/lean_spec/subspecs/xmss/test_containers.py +++ b/tests/lean_spec/subspecs/xmss/test_containers.py @@ -1,4 +1,4 @@ -"""Behaviour tests for ValidatorKeyPair.""" +"""Behaviour tests for the XMSS containers.""" import json @@ -6,8 +6,17 @@ from consensus_testing.keys import XmssKeyManager from pydantic import ValidationError -from lean_spec.subspecs.xmss.containers import KeyPair, ValidatorKeyPair -from lean_spec.types import ValidatorIndex +from lean_spec.subspecs.koalabear.field import P_BYTES +from lean_spec.subspecs.xmss.constants import TEST_CONFIG +from lean_spec.subspecs.xmss.containers import ( + KeyPair, + PublicKey, + SecretKey, + Signature, + ValidatorKeyPair, +) +from lean_spec.subspecs.xmss.interface import TEST_SIGNATURE_SCHEME +from lean_spec.types import Bytes32, Slot, Uint64, ValidatorIndex @pytest.fixture(scope="module") @@ -213,3 +222,101 @@ def test_keypair_frozen(keypair_a: KeyPair) -> None: """KeyPair fields cannot be reassigned (StrictBaseModel is frozen).""" with pytest.raises(ValidationError): keypair_a.public_key = keypair_a.public_key + + +def test_keypair_decodes_public_and_secret_hex(keypair_a: KeyPair) -> None: + """A key pair validates from hex strings for both halves.""" + decoded = KeyPair.model_validate( + { + "public_key": keypair_a.public_key.encode_bytes().hex(), + "secret_key": keypair_a.secret_key.encode_bytes().hex(), + } + ) + assert decoded == keypair_a + + +def test_keypair_rejects_invalid_public_key_hex(keypair_a: KeyPair) -> None: + """A malformed public-key hex string surfaces as a validation error.""" + with pytest.raises(ValidationError, match="invalid public key hex"): + KeyPair.model_validate( + { + "public_key": "deadbeef", + "secret_key": keypair_a.secret_key.encode_bytes().hex(), + } + ) + + +def test_keypair_rejects_invalid_secret_key_hex(keypair_a: KeyPair) -> None: + """A malformed secret-key hex string surfaces as a validation error.""" + with pytest.raises(ValidationError, match="invalid secret key hex"): + KeyPair.model_validate( + { + "public_key": keypair_a.public_key.encode_bytes().hex(), + "secret_key": "deadbeef", + } + ) + + +@pytest.fixture(scope="module") +def signed_key_pair() -> KeyPair: + """A key pair generated directly from the test scheme.""" + return TEST_SIGNATURE_SCHEME.key_gen(Slot(0), Uint64(32)) + + +@pytest.fixture(scope="module") +def sample_signature(signed_key_pair: KeyPair) -> Signature: + """A signature over a fixed message at slot zero.""" + return TEST_SIGNATURE_SCHEME.sign( + signed_key_pair.secret_key, Slot(0), Bytes32(bytes([42] * 32)) + ) + + +def test_public_key_ssz_roundtrip(signed_key_pair: KeyPair) -> None: + """A public key encodes and decodes back to an equal value.""" + public_key = signed_key_pair.public_key + assert PublicKey.decode_bytes(public_key.encode_bytes()) == public_key + + +def test_public_key_encoded_size_matches_layout(signed_key_pair: KeyPair) -> None: + """The encoded public key is the digest plus parameter packed into field bytes.""" + encoded = signed_key_pair.public_key.encode_bytes() + expected = (TEST_CONFIG.HASH_LEN_FE + TEST_CONFIG.PARAMETER_LEN) * P_BYTES + assert len(encoded) == expected + + +def test_secret_key_ssz_roundtrip(signed_key_pair: KeyPair) -> None: + """A secret key encodes and decodes back to an equal value.""" + secret_key = signed_key_pair.secret_key + assert SecretKey.decode_bytes(secret_key.encode_bytes()) == secret_key + + +def test_signature_is_fixed_size() -> None: + """A signature reports as fixed-size on the wire.""" + assert Signature.is_fixed_size() is True + + +def test_signature_byte_length_matches_config() -> None: + """The signature byte length matches the configured fixed length.""" + assert Signature.get_byte_length() == TEST_CONFIG.SIGNATURE_LEN_BYTES + + +def test_signature_ssz_roundtrip(sample_signature: Signature) -> None: + """A signature encodes and decodes back to an equal value.""" + assert Signature.decode_bytes(sample_signature.encode_bytes()) == sample_signature + + +def test_signature_encoded_size_matches_config(sample_signature: Signature) -> None: + """The encoded signature length matches the advertised fixed length.""" + assert len(sample_signature.encode_bytes()) == TEST_CONFIG.SIGNATURE_LEN_BYTES + + +def test_signature_json_is_prefixed_hex(sample_signature: Signature) -> None: + """The JSON form is a hex string prefixed with the byte marker.""" + dumped = json.loads(sample_signature.model_dump_json()) + assert dumped == "0x" + sample_signature.encode_bytes().hex() + + +def test_signature_decodes_from_json(sample_signature: Signature) -> None: + """A signature decodes back from its JSON hex form.""" + encoded = "0x" + sample_signature.encode_bytes().hex() + assert Signature.from_hex(encoded) == sample_signature diff --git a/tests/lean_spec/subspecs/xmss/test_encoding.py b/tests/lean_spec/subspecs/xmss/test_encoding.py new file mode 100644 index 000000000..bd45df48e --- /dev/null +++ b/tests/lean_spec/subspecs/xmss/test_encoding.py @@ -0,0 +1,147 @@ +"""Tests for the message-to-codeword encoding pipeline.""" + +import pytest + +from lean_spec.subspecs.koalabear import Fp, P +from lean_spec.subspecs.xmss import encoding +from lean_spec.subspecs.xmss.constants import TEST_CONFIG, TWEAK_PREFIX_MESSAGE +from lean_spec.subspecs.xmss.encoding import ( + aborting_decode, + encode_epoch, + encode_message, + message_hash, + target_sum_encode, +) +from lean_spec.subspecs.xmss.field import int_to_base_p, random_field_elements +from lean_spec.subspecs.xmss.poseidon import TEST_POSEIDON +from lean_spec.subspecs.xmss.types import Parameter, Randomness +from lean_spec.types import Bytes32, Uint64 + + +def _parameter() -> Parameter: + """Return a fixed public parameter for encoding tests.""" + return Parameter(data=[Fp(value=1)] * TEST_CONFIG.PARAMETER_LEN) + + +def test_encode_message_zero_is_all_zero_limbs() -> None: + """An all-zero message encodes to all-zero field elements.""" + encoded = encode_message(TEST_CONFIG, Bytes32(b"\x00" * 32)) + assert encoded == [Fp(value=0)] * TEST_CONFIG.MSG_LEN_FE + + +def test_encode_message_reads_little_endian() -> None: + """A maximal message encodes to its little-endian base-P decomposition.""" + message = Bytes32(b"\xff" * 32) + acc = int.from_bytes(message, "little") + assert encode_message(TEST_CONFIG, message) == int_to_base_p(acc, TEST_CONFIG.MSG_LEN_FE) + + +@pytest.mark.parametrize("epoch", [0, 42, 2**32 - 1]) +def test_encode_epoch_matches_prefixed_decomposition(epoch: int) -> None: + """An epoch encodes to its value shifted above the message prefix.""" + acc = (epoch << 8) | TWEAK_PREFIX_MESSAGE + expected = int_to_base_p(acc, TEST_CONFIG.TWEAK_LEN_FE) + assert encode_epoch(TEST_CONFIG, Uint64(epoch)) == expected + + +def test_encode_epoch_is_injective_over_a_range() -> None: + """Distinct epochs in a range encode to distinct field-element tuples.""" + encodings = {tuple(encode_epoch(TEST_CONFIG, Uint64(i))) for i in range(1000)} + assert len(encodings) == 1000 + + +def test_aborting_decode_known_decomposition() -> None: + """A hand-built quotient decodes to its base-BASE digits, truncated to the dimension.""" + config = TEST_CONFIG + d_value = 5 + fe_list = [Fp(value=config.Q * d_value)] * config.MH_HASH_LEN_FE + + expected_per_fe = [] + remaining = d_value + for _ in range(config.Z): + expected_per_fe.append(remaining % config.BASE) + remaining //= config.BASE + expected = (expected_per_fe * config.MH_HASH_LEN_FE)[: config.DIMENSION] + + assert aborting_decode(config, fe_list) == expected + + +def test_aborting_decode_accepts_largest_valid_element() -> None: + """The element just below the abort threshold decodes successfully.""" + config = TEST_CONFIG + result = aborting_decode(config, [Fp(value=P - 2)] * config.MH_HASH_LEN_FE) + assert result is not None + assert len(result) == config.DIMENSION + assert all(0 <= d < config.BASE for d in result) + + +def test_aborting_decode_rejects_threshold_element() -> None: + """The element equal to the prime minus one triggers the abort.""" + assert aborting_decode(TEST_CONFIG, [Fp(value=P - 1)]) is None + + +def test_message_hash_yields_valid_codeword() -> None: + """The message hash decodes to a codeword of dimension digits in range.""" + config = TEST_CONFIG + parameter = Parameter(data=random_field_elements(config.PARAMETER_LEN)) + randomness = Randomness(data=random_field_elements(config.RAND_LEN_FE)) + + result = message_hash( + TEST_POSEIDON, config, parameter, Uint64(313), randomness, Bytes32(b"\xaa" * 32) + ) + + assert result is not None + assert len(result) == config.DIMENSION + assert all(0 <= digit < config.BASE for digit in result) + + +def test_target_sum_encode_accepts_codeword_on_target_layer() -> None: + """Randomness whose codeword sums to the target is accepted.""" + config = TEST_CONFIG + parameter = _parameter() + # Attempt counter three lands the all-zero message on the target-sum layer. + rho = Randomness(data=int_to_base_p(3, config.RAND_LEN_FE)) + + codeword = target_sum_encode( + TEST_POSEIDON, config, parameter, Bytes32(b"\x00" * 32), rho, Uint64(0) + ) + + # The digits sum to the target of six, landing on the accepted layer. + assert codeword == [3, 0, 3, 0] + + +def test_target_sum_encode_rejects_codeword_off_target_layer() -> None: + """Randomness whose codeword misses the target sum is rejected.""" + config = TEST_CONFIG + parameter = _parameter() + # Attempt counter zero produces a codeword whose digits do not sum to the target. + rho = Randomness(data=int_to_base_p(0, config.RAND_LEN_FE)) + + assert ( + target_sum_encode(TEST_POSEIDON, config, parameter, Bytes32(b"\x00" * 32), rho, Uint64(0)) + is None + ) + + +def test_target_sum_encode_propagates_aborting_decode_failure( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """An aborting message hash makes the encode return None before the sum check. + + The aborting decode rejects only the prime-minus-one field element. + That event has probability near one in two billion per element. + It cannot be triggered with real inputs in a test, so the hash result is forced. + """ + monkeypatch.setattr(encoding, "message_hash", lambda *args, **kwargs: None) + + assert ( + target_sum_encode( + TEST_POSEIDON, + TEST_CONFIG, + _parameter(), + Bytes32(b"\x00" * 32), + Randomness(data=int_to_base_p(0, TEST_CONFIG.RAND_LEN_FE)), + Uint64(0), + ) + is None + ) diff --git a/tests/lean_spec/subspecs/xmss/test_field.py b/tests/lean_spec/subspecs/xmss/test_field.py new file mode 100644 index 000000000..58535d1c8 --- /dev/null +++ b/tests/lean_spec/subspecs/xmss/test_field.py @@ -0,0 +1,88 @@ +"""Tests for the field-element decomposition and secure sampling helpers.""" + +import secrets + +import pytest + +from lean_spec.subspecs.koalabear import Fp, P +from lean_spec.subspecs.xmss.constants import TEST_CONFIG +from lean_spec.subspecs.xmss.field import ( + int_to_base_p, + random_domain, + random_field_elements, + random_parameter, +) +from lean_spec.subspecs.xmss.types import HashDigestVector, Parameter + + +@pytest.mark.parametrize( + "value, num_limbs, expected_values", + [ + pytest.param(0, 4, [0, 0, 0, 0], id="zero spreads to all-zero limbs"), + pytest.param(123, 4, [123, 0, 0, 0], id="small value in the lowest limb"), + pytest.param(P, 4, [0, 1, 0, 0], id="prime carries into the second limb"), + pytest.param(P - 1, 4, [P - 1, 0, 0, 0], id="largest single-limb value"), + pytest.param(3 * (P**2) + 2 * P + 1, 4, [1, 2, 3, 0], id="mixed multi-limb value"), + pytest.param(P**3 - 1, 3, [P - 1, P - 1, P - 1], id="all limbs saturated"), + ], +) +def test_int_to_base_p_known_decomposition( + value: int, num_limbs: int, expected_values: list[int] +) -> None: + """Decomposition matches hand-computed base-P limbs.""" + assert int_to_base_p(value, num_limbs) == [Fp(value=v) for v in expected_values] + + +def test_int_to_base_p_zero_limbs_returns_empty() -> None: + """Requesting zero limbs yields an empty list.""" + assert int_to_base_p(12345, 0) == [] + + +def test_int_to_base_p_truncates_high_limbs() -> None: + """A value wider than the requested limbs drops its high digits.""" + assert int_to_base_p(P**2 + P + 7, 1) == [Fp(value=7)] + + +def test_int_to_base_p_roundtrip_is_reversible() -> None: + """Decomposing then recomposing recovers the original integer.""" + num_limbs = 5 + original_limbs = [secrets.randbelow(P) for _ in range(num_limbs)] + original_value = sum(val * (P**i) for i, val in enumerate(original_limbs)) + + decomposed = [int(fp) for fp in int_to_base_p(original_value, num_limbs)] + + assert decomposed == original_limbs + + +def test_random_field_elements_length() -> None: + """The sampler returns exactly the requested number of elements.""" + assert len(random_field_elements(7)) == 7 + + +def test_random_field_elements_zero_length() -> None: + """A zero-length request yields an empty list.""" + assert random_field_elements(0) == [] + + +def test_random_field_elements_are_in_field_range() -> None: + """Every sampled element lies in the range zero up to the prime.""" + assert all(0 <= int(fe) < P for fe in random_field_elements(50)) + + +def test_random_field_elements_are_not_constant() -> None: + """A large sample is overwhelmingly unlikely to repeat a single value.""" + assert len({int(fe) for fe in random_field_elements(100)}) > 1 + + +def test_random_parameter_has_parameter_length() -> None: + """A sampled parameter has the configured parameter length.""" + parameter = random_parameter(TEST_CONFIG) + assert isinstance(parameter, Parameter) + assert len(parameter.data) == TEST_CONFIG.PARAMETER_LEN + + +def test_random_domain_has_hash_length() -> None: + """A sampled domain vector has the configured digest length.""" + domain = random_domain(TEST_CONFIG) + assert isinstance(domain, HashDigestVector) + assert len(domain.data) == TEST_CONFIG.HASH_LEN_FE diff --git a/tests/lean_spec/subspecs/xmss/test_interface.py b/tests/lean_spec/subspecs/xmss/test_interface.py index 2092df1d2..4be16f1cd 100644 --- a/tests/lean_spec/subspecs/xmss/test_interface.py +++ b/tests/lean_spec/subspecs/xmss/test_interface.py @@ -1,13 +1,13 @@ -""" -End-to-end tests for the Generalized XMSS signature scheme. -""" +"""End-to-end tests for the Generalized XMSS signature scheme and its helpers.""" import pytest +from lean_spec.subspecs.xmss import interface from lean_spec.subspecs.xmss.encoding import target_sum_encode from lean_spec.subspecs.xmss.interface import ( TEST_SIGNATURE_SCHEME, GeneralizedXmssScheme, + _expand_activation_time, ) from lean_spec.types import Bytes32, Slot, Uint64 @@ -205,6 +205,79 @@ def test_deterministic_signing() -> None: assert sig1.path.siblings == sig2.path.siblings +@pytest.mark.parametrize( + "log_lifetime, desired_activation, desired_num, expected_start, expected_end", + [ + pytest.param(8, 0, 16, 0, 2, id="boundary request widens to minimum two trees"), + pytest.param(8, 10, 5, 0, 2, id="unaligned request rounds onto tree boundaries"), + pytest.param(8, 0, 100, 0, 7, id="larger request spans seven trees"), + pytest.param(4, 0, 300, 0, 4, id="request wider than lifetime covers all of it"), + pytest.param(8, 32, 16, 2, 4, id="middle request stays put"), + pytest.param(8, 240, 30, 14, 16, id="request near the end slides back to the boundary"), + ], +) +def test_expand_activation_time( + log_lifetime: int, + desired_activation: int, + desired_num: int, + expected_start: int, + expected_end: int, +) -> None: + """The requested window snaps onto whole bottom trees, widened and clamped.""" + assert _expand_activation_time(log_lifetime, desired_activation, desired_num) == ( + expected_start, + expected_end, + ) + + +def test_key_gen_rejects_range_exceeding_lifetime() -> None: + """A requested range past the lifetime is refused.""" + with pytest.raises(ValueError, match="Activation range exceeds the key's lifetime."): + TEST_SIGNATURE_SCHEME.key_gen(Slot(200), Uint64(100)) + + +def test_sign_rejects_slot_outside_activation() -> None: + """Signing a slot the key was never activated for is refused.""" + secret_key = TEST_SIGNATURE_SCHEME.key_gen(Slot(0), Uint64(32)).secret_key + with pytest.raises(ValueError, match="Key is not active for the specified slot."): + TEST_SIGNATURE_SCHEME.sign(secret_key, Slot(200), Bytes32(b"\x42" * 32)) + + +def test_sign_raises_when_no_encoding_found(monkeypatch: pytest.MonkeyPatch) -> None: + """An encoding search that never succeeds raises after exhausting the attempts. + + The encoding is forced to always reject so the retry loop runs to its limit. + """ + secret_key = TEST_SIGNATURE_SCHEME.key_gen(Slot(0), Uint64(32)).secret_key + monkeypatch.setattr(interface, "target_sum_encode", lambda *args, **kwargs: None) + tries = TEST_SIGNATURE_SCHEME.config.MAX_TRIES + with pytest.raises( + RuntimeError, match=f"Failed to find a valid message encoding after {tries} tries." + ): + TEST_SIGNATURE_SCHEME.sign(secret_key, Slot(0), Bytes32(b"\x42" * 32)) + + +def test_sign_raises_on_wrong_codeword_dimension(monkeypatch: pytest.MonkeyPatch) -> None: + """An encoding returning the wrong number of digits raises. + + The encoding is forced to return one digit too few for the scheme dimension. + """ + secret_key = TEST_SIGNATURE_SCHEME.key_gen(Slot(0), Uint64(32)).secret_key + short = [0] * (TEST_SIGNATURE_SCHEME.config.DIMENSION - 1) + monkeypatch.setattr(interface, "target_sum_encode", lambda *args, **kwargs: short) + with pytest.raises( + RuntimeError, match="Encoding is broken: returned too many or too few chunks." + ): + TEST_SIGNATURE_SCHEME.sign(secret_key, Slot(0), Bytes32(b"\x42" * 32)) + + +def test_advance_preparation_is_a_noop_at_the_end() -> None: + """A two-tree key cannot advance and returns the same secret key.""" + leaves = TEST_SIGNATURE_SCHEME.config.LEAVES_PER_BOTTOM_TREE + secret_key = TEST_SIGNATURE_SCHEME.key_gen(Slot(0), Uint64(2 * leaves)).secret_key + assert TEST_SIGNATURE_SCHEME.advance_preparation(secret_key) is secret_key + + class TestVerifySecurityBounds: """ Security tests for verify method input validation. diff --git a/tests/lean_spec/subspecs/xmss/test_merkle.py b/tests/lean_spec/subspecs/xmss/test_merkle.py new file mode 100644 index 000000000..5059b9592 --- /dev/null +++ b/tests/lean_spec/subspecs/xmss/test_merkle.py @@ -0,0 +1,387 @@ +"""Tests for the sparse Merkle subtree implementation.""" + +import pytest + +from lean_spec.subspecs.xmss.constants import PROD_CONFIG, TEST_CONFIG, XmssConfig +from lean_spec.subspecs.xmss.field import random_domain, random_parameter +from lean_spec.subspecs.xmss.merkle import ( + HashSubTree, + HashTreeLayer, + HashTreeLayers, + combined_path, + verify_path, +) +from lean_spec.subspecs.xmss.poseidon import PROD_POSEIDON, TEST_POSEIDON, PoseidonXmss +from lean_spec.subspecs.xmss.prf import PRFKey +from lean_spec.subspecs.xmss.types import ( + HashDigestList, + HashDigestVector, + HashTreeOpening, + Parameter, + TreeTweak, +) +from lean_spec.types import Uint64 +from lean_spec.types.exceptions import SSZValueError + + +def _run_commit_open_verify_roundtrip( + poseidon: PoseidonXmss, + config: XmssConfig, + num_leaves: int, + depth: int, + start_index: int, + leaf_parts_len: int, +) -> None: + """Build a tree, then open and verify every active leaf against its root.""" + parameter = random_parameter(config) + leaves: list[list[HashDigestVector]] = [ + [random_domain(config) for _ in range(leaf_parts_len)] for _ in range(num_leaves) + ] + + leaf_hashes: list[HashDigestVector] = [ + poseidon.tweak_hash( + config, + parameter, + TreeTweak(level=0, index=Uint64(start_index + i)), + leaf_parts, + ) + for i, leaf_parts in enumerate(leaves) + ] + + tree = HashSubTree.new( + poseidon=poseidon, + config=config, + lowest_layer=Uint64(0), + depth=Uint64(depth), + start_index=Uint64(start_index), + parameter=parameter, + lowest_layer_nodes=leaf_hashes, + ) + root = tree.root() + + for i, leaf_parts in enumerate(leaves): + position = Uint64(start_index + i) + opening = tree.path(position) + assert verify_path( + poseidon=poseidon, + config=config, + parameter=parameter, + root=root, + position=position, + leaf_parts=leaf_parts, + opening=opening, + ) + + +@pytest.mark.parametrize( + "num_leaves, depth, start_index, leaf_parts_len", + [ + pytest.param(16, 4, 0, 3, id="Full tree (depth 4)", marks=pytest.mark.slow), + pytest.param(12, 5, 0, 5, id="Half tree, left-aligned (depth 5)", marks=pytest.mark.slow), + pytest.param(16, 5, 16, 2, id="Half tree, right-aligned (depth 5)"), + pytest.param(22, 6, 13, 3, id="Sparse, non-aligned tree (depth 6)", marks=pytest.mark.slow), + pytest.param(2, 2, 2, 6, id="Half tree, right-aligned (small)"), + pytest.param(1, 1, 0, 1, id="Tree with a single leaf at the start"), + pytest.param(1, 1, 1, 1, id="Tree with a single leaf at an odd index"), + pytest.param(16, 5, 7, 2, id="Small sparse tree starting at an odd index"), + ], +) +def test_commit_open_verify_roundtrip( + num_leaves: int, + depth: int, + start_index: int, + leaf_parts_len: int, +) -> None: + """A built tree opens and verifies every leaf for various shapes.""" + assert start_index + num_leaves <= (1 << depth) + _run_commit_open_verify_roundtrip( + PROD_POSEIDON, PROD_CONFIG, num_leaves, depth, start_index, leaf_parts_len + ) + + +def test_new_rejects_nodes_overflowing_their_level() -> None: + """A node run that does not fit its level raises with the level and bounds named.""" + with pytest.raises(ValueError, match=r"Overflow at layer 0: start=3, count=2, max=4"): + HashSubTree.new( + poseidon=TEST_POSEIDON, + config=TEST_CONFIG, + lowest_layer=Uint64(0), + depth=Uint64(2), + start_index=Uint64(3), + parameter=random_parameter(TEST_CONFIG), + lowest_layer_nodes=[random_domain(TEST_CONFIG), random_domain(TEST_CONFIG)], + ) + + +def test_new_top_tree_rejects_odd_depth() -> None: + """The top tree requires an even depth for the top-bottom split.""" + with pytest.raises(ValueError, match=r"Depth must be even for top-bottom split, got 7."): + HashSubTree.new_top_tree( + TEST_POSEIDON, TEST_CONFIG, 7, Uint64(0), random_parameter(TEST_CONFIG), [] + ) + + +def test_new_bottom_tree_rejects_odd_depth() -> None: + """The bottom tree requires an even depth for the top-bottom split.""" + with pytest.raises(ValueError, match=r"Depth must be even for top-bottom split, got 7."): + HashSubTree.new_bottom_tree( + TEST_POSEIDON, TEST_CONFIG, 7, Uint64(0), random_parameter(TEST_CONFIG), [] + ) + + +def test_new_bottom_tree_rejects_wrong_leaf_count() -> None: + """The bottom tree requires exactly the square-root-of-lifetime leaves.""" + with pytest.raises(ValueError, match=r"Expected 16 leaves for depth=8, got 0."): + HashSubTree.new_bottom_tree( + TEST_POSEIDON, TEST_CONFIG, 8, Uint64(0), random_parameter(TEST_CONFIG), [] + ) + + +def test_root_rejects_empty_subtree() -> None: + """A subtree with no layers has no root.""" + subtree = HashSubTree( + depth=Uint64(8), + lowest_layer=Uint64(0), + layers=HashTreeLayers(data=[]), + ) + with pytest.raises(ValueError, match=r"Empty subtree has no root."): + subtree.root() + + +def test_root_rejects_empty_top_layer() -> None: + """A subtree whose top layer holds no nodes has no root.""" + empty_layer = HashTreeLayer(start_index=Uint64(0), nodes=HashDigestList(data=[])) + subtree = HashSubTree( + depth=Uint64(8), + lowest_layer=Uint64(0), + layers=HashTreeLayers(data=[empty_layer]), + ) + with pytest.raises(ValueError, match=r"Top layer is empty."): + subtree.root() + + +def test_path_rejects_empty_subtree() -> None: + """Opening a path on a subtree with no layers raises.""" + subtree = HashSubTree( + depth=Uint64(8), + lowest_layer=Uint64(0), + layers=HashTreeLayers(data=[]), + ) + with pytest.raises(ValueError, match=r"Empty subtree."): + subtree.path(Uint64(0)) + + +def test_path_rejects_position_out_of_bounds() -> None: + """A position outside the lowest layer's stored range raises.""" + layer = HashTreeLayer( + start_index=Uint64(0), + nodes=HashDigestList(data=[random_domain(TEST_CONFIG), random_domain(TEST_CONFIG)]), + ) + subtree = HashSubTree( + depth=Uint64(8), + lowest_layer=Uint64(0), + layers=HashTreeLayers(data=[layer]), + ) + with pytest.raises(ValueError, match=r"Position 5 out of bounds."): + subtree.path(Uint64(5)) + + +def test_path_rejects_sibling_out_of_bounds() -> None: + """A non-root layer too small to hold the needed sibling raises.""" + leaf_layer = HashTreeLayer( + start_index=Uint64(0), + nodes=HashDigestList(data=[random_domain(TEST_CONFIG), random_domain(TEST_CONFIG)]), + ) + # The middle layer lacks the sibling at index one. + middle_layer = HashTreeLayer( + start_index=Uint64(0), nodes=HashDigestList(data=[random_domain(TEST_CONFIG)]) + ) + root_layer = HashTreeLayer( + start_index=Uint64(0), nodes=HashDigestList(data=[random_domain(TEST_CONFIG)]) + ) + subtree = HashSubTree( + depth=Uint64(8), + lowest_layer=Uint64(0), + layers=HashTreeLayers(data=[leaf_layer, middle_layer, root_layer]), + ) + with pytest.raises(ValueError, match=r"Sibling index 1 out of bounds."): + subtree.path(Uint64(0)) + + +@pytest.fixture(scope="module") +def prf_trees() -> tuple[Parameter, HashSubTree, HashSubTree, HashSubTree]: + """Build a top tree over two prf-derived bottom trees for combined-path tests.""" + config = TEST_CONFIG + prf_key = PRFKey.generate() + parameter = random_parameter(config) + bottom_zero = HashSubTree.from_prf_key( + poseidon=TEST_POSEIDON, + config=config, + prf_key=prf_key, + bottom_tree_index=Uint64(0), + parameter=parameter, + ) + bottom_one = HashSubTree.from_prf_key( + poseidon=TEST_POSEIDON, + config=config, + prf_key=prf_key, + bottom_tree_index=Uint64(1), + parameter=parameter, + ) + top = HashSubTree.new_top_tree( + TEST_POSEIDON, + config, + config.LOG_LIFETIME, + Uint64(0), + parameter, + [bottom_zero.root(), bottom_one.root()], + ) + return parameter, top, bottom_zero, bottom_one + + +def test_combined_path_authenticates_leaf_to_global_root( + prf_trees: tuple[Parameter, HashSubTree, HashSubTree, HashSubTree], +) -> None: + """A combined opening spans the full depth from leaf to global root.""" + _, top, bottom_zero, _ = prf_trees + opening = combined_path(top, bottom_zero, Uint64(0)) + assert len(opening.siblings) == TEST_CONFIG.LOG_LIFETIME + + +def test_combined_path_rejects_depth_mismatch( + prf_trees: tuple[Parameter, HashSubTree, HashSubTree, HashSubTree], +) -> None: + """Top and bottom trees of disagreeing depth cannot be stitched.""" + _, top, bottom_zero, _ = prf_trees + mismatched = bottom_zero.model_copy(update={"depth": Uint64(6)}) + with pytest.raises(ValueError, match=r"Depth mismatch: top=8, bottom=6."): + combined_path(top, mismatched, Uint64(0)) + + +def test_combined_path_rejects_odd_depth( + prf_trees: tuple[Parameter, HashSubTree, HashSubTree, HashSubTree], +) -> None: + """Stitching requires an even depth.""" + _, top, bottom_zero, _ = prf_trees + odd_top = top.model_copy(update={"depth": Uint64(7)}) + odd_bottom = bottom_zero.model_copy(update={"depth": Uint64(7)}) + with pytest.raises(ValueError, match=r"Depth must be even, got 7."): + combined_path(odd_top, odd_bottom, Uint64(7)) + + +def test_combined_path_rejects_wrong_bottom_tree( + prf_trees: tuple[Parameter, HashSubTree, HashSubTree, HashSubTree], +) -> None: + """A position belonging to a sibling bottom tree is refused.""" + _, top, bottom_zero, _ = prf_trees + with pytest.raises( + ValueError, + match=r"Wrong bottom tree: position 16 needs start 16, got 0.", + ): + combined_path(top, bottom_zero, Uint64(16)) + + +def test_from_prf_key_builds_a_bottom_tree() -> None: + """A prf-derived bottom tree has the expected depth, layers, and leaf count.""" + config = TEST_CONFIG + bottom_tree = HashSubTree.from_prf_key( + poseidon=TEST_POSEIDON, + config=config, + prf_key=PRFKey.generate(), + bottom_tree_index=Uint64(0), + parameter=random_parameter(config), + ) + assert bottom_tree.depth == Uint64(config.LOG_LIFETIME) + assert bottom_tree.lowest_layer == Uint64(0) + assert len(bottom_tree.layers[-1].nodes) == 1 + assert len(bottom_tree.layers[0].nodes) == config.LEAVES_PER_BOTTOM_TREE + + +def test_from_prf_key_is_deterministic() -> None: + """The same seed and index rebuild the same bottom-tree root.""" + config = TEST_CONFIG + prf_key = PRFKey.generate() + parameter = random_parameter(config) + first = HashSubTree.from_prf_key( + poseidon=TEST_POSEIDON, + config=config, + prf_key=prf_key, + bottom_tree_index=Uint64(0), + parameter=parameter, + ) + second = HashSubTree.from_prf_key( + poseidon=TEST_POSEIDON, + config=config, + prf_key=prf_key, + bottom_tree_index=Uint64(0), + parameter=parameter, + ) + assert first.root() == second.root() + + +def test_from_prf_key_distinct_indices_give_distinct_roots() -> None: + """Different bottom-tree indices produce different roots.""" + config = TEST_CONFIG + prf_key = PRFKey.generate() + parameter = random_parameter(config) + tree_zero = HashSubTree.from_prf_key( + poseidon=TEST_POSEIDON, + config=config, + prf_key=prf_key, + bottom_tree_index=Uint64(0), + parameter=parameter, + ) + tree_one = HashSubTree.from_prf_key( + poseidon=TEST_POSEIDON, + config=config, + prf_key=prf_key, + bottom_tree_index=Uint64(1), + parameter=parameter, + ) + assert tree_zero.root() != tree_one.root() + + +def test_verify_path_rejects_excessive_depth_at_ssz_level() -> None: + """ + The SSZ type system caps an opening at the layer limit. + + The defensive depth guard inside verification cannot be reached through a + well-formed opening, since the digest list rejects more than its limit. + """ + with pytest.raises(SSZValueError): + HashDigestList(data=[random_domain(PROD_CONFIG) for _ in range(33)]) + + +@pytest.mark.parametrize("position", [16, 100]) +def test_verify_path_rejects_position_exceeding_capacity(position: int) -> None: + """A position at or beyond two-to-the-depth fails without raising.""" + siblings = [random_domain(PROD_CONFIG) for _ in range(4)] + opening = HashTreeOpening(siblings=HashDigestList(data=siblings)) + assert ( + verify_path( + poseidon=PROD_POSEIDON, + config=PROD_CONFIG, + parameter=random_parameter(PROD_CONFIG), + root=random_domain(PROD_CONFIG), + position=Uint64(position), + leaf_parts=[random_domain(PROD_CONFIG)], + opening=opening, + ) + is False + ) + + +def test_verify_path_accepts_boundary_position_without_raising() -> None: + """The maximum valid position for the depth does not trip the bounds guard.""" + siblings = [random_domain(PROD_CONFIG) for _ in range(4)] + opening = HashTreeOpening(siblings=HashDigestList(data=siblings)) + result = verify_path( + poseidon=PROD_POSEIDON, + config=PROD_CONFIG, + parameter=random_parameter(PROD_CONFIG), + root=random_domain(PROD_CONFIG), + position=Uint64(15), + leaf_parts=[random_domain(PROD_CONFIG)], + opening=opening, + ) + assert isinstance(result, bool) diff --git a/tests/lean_spec/subspecs/xmss/test_merkle_tree.py b/tests/lean_spec/subspecs/xmss/test_merkle_tree.py deleted file mode 100644 index 9b3d59823..000000000 --- a/tests/lean_spec/subspecs/xmss/test_merkle_tree.py +++ /dev/null @@ -1,197 +0,0 @@ -"""Tests for the sparse Merkle tree implementation.""" - -import pytest - -from lean_spec.subspecs.xmss.constants import PROD_CONFIG, XmssConfig -from lean_spec.subspecs.xmss.field import random_domain, random_parameter -from lean_spec.subspecs.xmss.merkle import HashSubTree, verify_path -from lean_spec.subspecs.xmss.poseidon import PROD_POSEIDON, PoseidonXmss -from lean_spec.subspecs.xmss.types import ( - HashDigestList, - HashDigestVector, - HashTreeOpening, - TreeTweak, -) -from lean_spec.types import Uint64 -from lean_spec.types.exceptions import SSZValueError - - -def _run_commit_open_verify_roundtrip( - poseidon: PoseidonXmss, - config: XmssConfig, - num_leaves: int, - depth: int, - start_index: int, - leaf_parts_len: int, -) -> None: - """ - A helper function to perform a full Merkle tree roundtrip test. - - The process is as follows: - 1. Generate random leaf data. - 2. Hash the leaves to create layer 0 of the tree. - 3. Build the full Merkle tree and get its root (commit). - 4. For each leaf, generate an authentication path (open). - 5. Verify that each path is valid for its corresponding leaf and root. - - Args: - poseidon: Cached Poseidon1 engine. - config: Active XMSS configuration. - num_leaves: The number of active leaves in the tree. - depth: The total depth of the Merkle tree. - start_index: The starting index of the first active leaf. - leaf_parts_len: The number of digests that constitute a single leaf. - """ - # SETUP: Generate a random parameter and the raw leaf data. - parameter = random_parameter(config) - leaves: list[list[HashDigestVector]] = [ - [random_domain(config) for _ in range(leaf_parts_len)] for _ in range(num_leaves) - ] - - # HASH LEAVES: Compute the layer 0 nodes by hashing the leaf parts. - leaf_hashes: list[HashDigestVector] = [ - poseidon.tweak_hash( - config, - parameter, - TreeTweak(level=0, index=Uint64(start_index + i)), - leaf_parts, - ) - for i, leaf_parts in enumerate(leaves) - ] - - # COMMIT: Build the Merkle tree from the leaf hashes. - tree = HashSubTree.new( - poseidon=poseidon, - config=config, - lowest_layer=Uint64(0), - depth=Uint64(depth), - start_index=Uint64(start_index), - parameter=parameter, - lowest_layer_nodes=leaf_hashes, - ) - root = tree.root() - - # OPEN & VERIFY: For each leaf, generate and verify its path. - for i, leaf_parts in enumerate(leaves): - position = Uint64(start_index + i) - opening = tree.path(position) - is_valid = verify_path( - poseidon=poseidon, - config=config, - parameter=parameter, - root=root, - position=position, - leaf_parts=leaf_parts, - opening=opening, - ) - assert is_valid, f"Verification failed for leaf at position {position}" - - -@pytest.mark.parametrize( - "num_leaves, depth, start_index, leaf_parts_len", - [ - pytest.param(16, 4, 0, 3, id="Full tree (depth 4)", marks=pytest.mark.slow), - pytest.param(12, 5, 0, 5, id="Half tree, left-aligned (depth 5)", marks=pytest.mark.slow), - pytest.param(16, 5, 16, 2, id="Half tree, right-aligned (depth 5)"), - pytest.param(22, 6, 13, 3, id="Sparse, non-aligned tree (depth 6)", marks=pytest.mark.slow), - pytest.param(2, 2, 2, 6, id="Half tree, right-aligned (small)"), - pytest.param(1, 1, 0, 1, id="Tree with a single leaf at the start"), - pytest.param(1, 1, 1, 1, id="Tree with a single leaf at an odd index"), - pytest.param(16, 5, 7, 2, id="Small sparse tree starting at an odd index"), - ], -) -def test_commit_open_verify_roundtrip( - num_leaves: int, - depth: int, - start_index: int, - leaf_parts_len: int, -) -> None: - """Tests the Merkle tree logic for various configurations.""" - # Ensure the test case parameters are valid for the specified tree depth. - assert start_index + num_leaves <= (1 << depth) - - _run_commit_open_verify_roundtrip( - PROD_POSEIDON, PROD_CONFIG, num_leaves, depth, start_index, leaf_parts_len - ) - - -class TestVerifyPathSecurityBounds: - """ - Security tests for verify_path input validation. - - Verification functions must return False (not raise) on attacker-controlled invalid input. - This prevents denial-of-service via malformed signatures. - """ - - def test_ssz_validation_rejects_excessive_depth(self) -> None: - """ - SSZ type system rejects openings with depth > 32. - - HashDigestList has a LIMIT of 32, so the type system prevents - creating malformed openings at the SSZ level. The check in - verify_path is defense-in-depth for deserialized data. - """ - # Attempting to create a list with 33 siblings raises at the type level. - excessive_siblings = [random_domain(PROD_CONFIG) for _ in range(33)] - with pytest.raises(SSZValueError): - HashDigestList(data=excessive_siblings) - - def test_rejects_position_exceeding_tree_capacity(self) -> None: - """verify_path returns False when position >= 2^depth.""" - parameter = random_parameter(PROD_CONFIG) - - root = random_domain(PROD_CONFIG) - leaf_parts = [random_domain(PROD_CONFIG)] - - # Create an opening with depth=4 (supports positions 0-15). - siblings = [random_domain(PROD_CONFIG) for _ in range(4)] - opening = HashTreeOpening(siblings=HashDigestList(data=siblings)) - - # Position 16 is out of bounds for depth 4 (capacity = 2^4 = 16). - result = verify_path( - poseidon=PROD_POSEIDON, - config=PROD_CONFIG, - parameter=parameter, - root=root, - position=Uint64(16), - leaf_parts=leaf_parts, - opening=opening, - ) - assert result is False - - # Position 100 is also out of bounds. - result = verify_path( - poseidon=PROD_POSEIDON, - config=PROD_CONFIG, - parameter=parameter, - root=root, - position=Uint64(100), - leaf_parts=leaf_parts, - opening=opening, - ) - assert result is False - - def test_valid_position_at_boundary(self) -> None: - """verify_path accepts position at maximum valid value (2^depth - 1).""" - parameter = random_parameter(PROD_CONFIG) - - root = random_domain(PROD_CONFIG) - leaf_parts = [random_domain(PROD_CONFIG)] - - # Create an opening with depth=4. - siblings = [random_domain(PROD_CONFIG) for _ in range(4)] - opening = HashTreeOpening(siblings=HashDigestList(data=siblings)) - - # Position 15 is the maximum valid position for depth 4. - # This should not return False due to bounds check (may still fail root check). - result = verify_path( - poseidon=PROD_POSEIDON, - config=PROD_CONFIG, - parameter=parameter, - root=root, - position=Uint64(15), - leaf_parts=leaf_parts, - opening=opening, - ) - # Result may be False due to wrong root, but importantly it didn't raise. - assert isinstance(result, bool) diff --git a/tests/lean_spec/subspecs/xmss/test_message_hash.py b/tests/lean_spec/subspecs/xmss/test_message_hash.py deleted file mode 100644 index 6d42a43e4..000000000 --- a/tests/lean_spec/subspecs/xmss/test_message_hash.py +++ /dev/null @@ -1,122 +0,0 @@ -""" -Tests for the message hashing and aborting hypercube encoding logic. -""" - -from lean_spec.subspecs.koalabear import Fp, P -from lean_spec.subspecs.xmss.constants import ( - TEST_CONFIG, - TWEAK_PREFIX_MESSAGE, -) -from lean_spec.subspecs.xmss.encoding import ( - aborting_decode, - encode_epoch, - encode_message, - message_hash, -) -from lean_spec.subspecs.xmss.field import int_to_base_p, random_field_elements -from lean_spec.subspecs.xmss.poseidon import TEST_POSEIDON -from lean_spec.subspecs.xmss.types import Parameter, Randomness -from lean_spec.types import Bytes32, Uint64 - - -def test_encode_message() -> None: - """Tests encode_message with various message patterns.""" - config = TEST_CONFIG - - # All-zero message - msg_zeros = Bytes32(b"\x00" * 32) - encoded_zeros = encode_message(config, msg_zeros) - assert len(encoded_zeros) == config.MSG_LEN_FE - assert all(fe == Fp(value=0) for fe in encoded_zeros) - - # All-max message (0xff) - msg_max = Bytes32(b"\xff" * 32) - acc = int.from_bytes(msg_max, "little") - expected_max = int_to_base_p(acc, config.MSG_LEN_FE) - assert encode_message(config, msg_max) == expected_max - - -def test_encode_epoch() -> None: - """ - Tests encode_epoch for correctness and injectivity. - """ - config = TEST_CONFIG - - # Test specific values from the Rust reference tests. - test_epochs = [0, 42, 2**32 - 1] - for epoch in test_epochs: - acc = (epoch << 8) | TWEAK_PREFIX_MESSAGE - expected = int_to_base_p(acc, config.TWEAK_LEN_FE) - assert encode_epoch(config, Uint64(epoch)) == expected - - # Test for injectivity. It is highly unlikely for a collision to occur - # with a few random samples if the encoding is injective. - num_trials = 1000 - seen_encodings: set[tuple[Fp, ...]] = set() - for i in range(num_trials): - encoding = tuple(encode_epoch(config, Uint64(i))) - assert encoding not in seen_encodings - seen_encodings.add(encoding) - - -def test_aborting_decode_known_decomposition() -> None: - """Verifies aborting decode with a hand-computed example.""" - config = TEST_CONFIG - - # Pick an arbitrary quotient multiplier to build a valid field element. - d_value = 5 - fe_list = [Fp(value=config.Q * d_value)] * config.MH_HASH_LEN_FE - result = aborting_decode(config, fe_list) - assert result is not None - assert len(result) == config.DIMENSION - - # Each FE decomposes d_value into Z base-BASE digits (LSB first), - # then the first DIMENSION digits are taken across all FEs. - digits_per_fe = [] - remaining = d_value - for _ in range(config.Z): - digits_per_fe.append(remaining % config.BASE) - remaining //= config.BASE - all_digits = (digits_per_fe * config.MH_HASH_LEN_FE)[: config.DIMENSION] - assert result == all_digits - - -def test_aborting_decode_boundary() -> None: - """Tests that FE = P-2 succeeds and FE = P-1 aborts.""" - config = TEST_CONFIG - - # P - 2 is the largest valid value (just below Q * BASE^Z = P - 1). - fe_valid = [Fp(value=P - 2)] * config.MH_HASH_LEN_FE - result = aborting_decode(config, fe_valid) - assert result is not None - assert len(result) == config.DIMENSION - assert all(0 <= d < config.BASE for d in result) - - # P - 1 triggers the abort (A_i >= Q * BASE^Z). - fe_abort = [Fp(value=P - 1)] - result = aborting_decode(config, fe_abort) - assert result is None - - -def test_apply_output_is_valid_codeword() -> None: - """ - Tests that the output of message_hash is None or a valid codeword with - DIMENSION digits each in [0, BASE-1]. - """ - config = TEST_CONFIG - - # Setup with random inputs. - parameter = Parameter(data=random_field_elements(config.PARAMETER_LEN)) - epoch = Uint64(313) - randomness = Randomness(data=random_field_elements(config.RAND_LEN_FE)) - message = Bytes32(b"\xaa" * 32) - - # Call the message hash function. - result = message_hash(TEST_POSEIDON, config, parameter, epoch, randomness, message) - - # The aborting decode may return None, but in practice it almost never does. - assert result is not None - - # Verify the properties of the output codeword. - assert len(result) == config.DIMENSION - assert all(0 <= coord < config.BASE for coord in result) diff --git a/tests/lean_spec/subspecs/xmss/test_poseidon.py b/tests/lean_spec/subspecs/xmss/test_poseidon.py new file mode 100644 index 000000000..7dd605c97 --- /dev/null +++ b/tests/lean_spec/subspecs/xmss/test_poseidon.py @@ -0,0 +1,184 @@ +"""Tests for the Poseidon1 hash engine wrapper used by the XMSS scheme.""" + +import pytest + +from lean_spec.subspecs.koalabear import Fp +from lean_spec.subspecs.xmss.constants import TEST_CONFIG +from lean_spec.subspecs.xmss.field import random_domain +from lean_spec.subspecs.xmss.poseidon import TEST_POSEIDON +from lean_spec.subspecs.xmss.types import ( + ChainTweak, + HashDigestVector, + Parameter, + TreeTweak, +) +from lean_spec.types import Uint64 + + +def _parameter() -> Parameter: + """Return a fixed public parameter for hashing tests.""" + return Parameter(data=[Fp(value=1)] * TEST_CONFIG.PARAMETER_LEN) + + +@pytest.mark.parametrize("width", [16, 24]) +def test_get_engine_caches_supported_widths(width: int) -> None: + """A supported width yields the same engine instance on repeated calls.""" + assert TEST_POSEIDON._get_engine(width) is TEST_POSEIDON._get_engine(width) + + +@pytest.mark.parametrize("width", [0, 8, 15, 17, 32]) +def test_get_engine_rejects_unsupported_width(width: int) -> None: + """An unsupported width raises with the offending value named.""" + with pytest.raises(ValueError, match=f"Width must be 16 or 24, got {width}"): + TEST_POSEIDON._get_engine(width) + + +@pytest.mark.parametrize("width", [16, 24]) +def test_compress_returns_requested_output_length(width: int) -> None: + """Compression returns exactly the requested number of output elements.""" + result = TEST_POSEIDON.compress([Fp(value=i) for i in range(8)], width, 8) + assert len(result) == 8 + + +def test_compress_truncates_to_short_output() -> None: + """A short output length yields a truncated digest.""" + result = TEST_POSEIDON.compress([Fp(value=i) for i in range(8)], 16, 1) + assert len(result) == 1 + + +def test_compress_is_deterministic() -> None: + """The same input compresses to the same output.""" + a = TEST_POSEIDON.compress([Fp(value=i) for i in range(8)], 16, 8) + b = TEST_POSEIDON.compress([Fp(value=i) for i in range(8)], 16, 8) + assert a == b + + +def test_compress_rejects_output_longer_than_input() -> None: + """Requesting more output than the raw input length raises.""" + with pytest.raises(ValueError, match="Input vector is too short for requested output length."): + TEST_POSEIDON.compress([Fp(value=1), Fp(value=2)], 16, 8) + + +def test_safe_domain_separator_returns_capacity_length() -> None: + """The domain separator returns a vector of the requested capacity length.""" + assert len(TEST_POSEIDON.safe_domain_separator([5, 2, 4, 8], 9)) == 9 + + +def test_safe_domain_separator_distinguishes_shapes() -> None: + """Different length parameters produce different capacity values.""" + assert TEST_POSEIDON.safe_domain_separator([1, 2], 9) != TEST_POSEIDON.safe_domain_separator( + [2, 1], 9 + ) + + +def test_sponge_returns_requested_output_length() -> None: + """The sponge returns exactly the requested number of output elements.""" + capacity = TEST_POSEIDON.safe_domain_separator([1, 2, 3, 4], 9) + assert len(TEST_POSEIDON.sponge([Fp(value=1)] * 5, capacity, 8, 24)) == 8 + + +def test_sponge_squeezes_more_than_one_rate_block() -> None: + """Requesting more output than one rate block permutes again to squeeze enough.""" + capacity = TEST_POSEIDON.safe_domain_separator([1, 2, 3, 4], 9) + rate = 24 - len(capacity) + assert len(TEST_POSEIDON.sponge([Fp(value=1)] * 5, capacity, rate + 1, 24)) == rate + 1 + + +def test_sponge_rejects_capacity_not_smaller_than_width() -> None: + """A capacity that fills the whole state leaves no rate slot and raises.""" + with pytest.raises(ValueError, match="Capacity length must be smaller than the state width."): + TEST_POSEIDON.sponge([Fp(value=1)], [Fp(value=0)] * 16, 1, 16) + + +def test_tweak_hash_chain_uses_width_sixteen_compression() -> None: + """A single digest input hashes through width-sixteen compression.""" + result = TEST_POSEIDON.tweak_hash( + TEST_CONFIG, + _parameter(), + ChainTweak(epoch=Uint64(0), chain_index=1, step=1), + [random_domain(TEST_CONFIG)], + ) + assert isinstance(result, HashDigestVector) + assert len(result.data) == TEST_CONFIG.HASH_LEN_FE + + +def test_tweak_hash_node_uses_width_twenty_four_compression() -> None: + """Two digest inputs hash through width-twenty-four compression.""" + result = TEST_POSEIDON.tweak_hash( + TEST_CONFIG, + _parameter(), + TreeTweak(level=1, index=Uint64(0)), + [random_domain(TEST_CONFIG), random_domain(TEST_CONFIG)], + ) + assert len(result.data) == TEST_CONFIG.HASH_LEN_FE + + +def test_tweak_hash_leaf_uses_sponge_mode() -> None: + """More than two digest inputs hash through sponge mode.""" + parts = [random_domain(TEST_CONFIG) for _ in range(TEST_CONFIG.DIMENSION)] + result = TEST_POSEIDON.tweak_hash( + TEST_CONFIG, _parameter(), TreeTweak(level=0, index=Uint64(0)), parts + ) + assert len(result.data) == TEST_CONFIG.HASH_LEN_FE + + +def test_tweak_hash_chain_and_tree_tweaks_are_domain_separated() -> None: + """A chain tweak and a tree tweak over one digest produce different hashes.""" + digest = random_domain(TEST_CONFIG) + chain = TEST_POSEIDON.tweak_hash( + TEST_CONFIG, _parameter(), ChainTweak(epoch=Uint64(0), chain_index=0, step=1), [digest] + ) + # A single-part tree tweak still routes through width-sixteen compression. + tree = TEST_POSEIDON.tweak_hash( + TEST_CONFIG, _parameter(), TreeTweak(level=0, index=Uint64(0)), [digest] + ) + assert chain != tree + + +def test_hash_chain_zero_steps_returns_start_digest() -> None: + """Walking zero steps returns the starting digest unchanged.""" + start = random_domain(TEST_CONFIG) + result = TEST_POSEIDON.hash_chain( + config=TEST_CONFIG, + parameter=_parameter(), + epoch=Uint64(0), + chain_index=0, + start_step=0, + num_steps=0, + start_digest=start, + ) + assert result == start + + +def test_hash_chain_is_composable() -> None: + """Walking two steps equals walking one step then one more.""" + parameter = _parameter() + start = random_domain(TEST_CONFIG) + two = TEST_POSEIDON.hash_chain( + config=TEST_CONFIG, + parameter=parameter, + epoch=Uint64(0), + chain_index=0, + start_step=0, + num_steps=2, + start_digest=start, + ) + one = TEST_POSEIDON.hash_chain( + config=TEST_CONFIG, + parameter=parameter, + epoch=Uint64(0), + chain_index=0, + start_step=0, + num_steps=1, + start_digest=start, + ) + one_more = TEST_POSEIDON.hash_chain( + config=TEST_CONFIG, + parameter=parameter, + epoch=Uint64(0), + chain_index=0, + start_step=1, + num_steps=1, + start_digest=one, + ) + assert two == one_more diff --git a/tests/lean_spec/subspecs/xmss/test_security_levels.py b/tests/lean_spec/subspecs/xmss/test_security_levels.py deleted file mode 100644 index 86b4a46c8..000000000 --- a/tests/lean_spec/subspecs/xmss/test_security_levels.py +++ /dev/null @@ -1,307 +0,0 @@ -""" -Validates that XMSS parameter choices achieve adequate classical and quantum security. - -Based on: - -- [DKKW25c] "Hash-Based Multi-Signatures for Post-Quantum Ethereum" - (https://eprint.iacr.org/2025/055.pdf) -- [HKKTW26] "Aborting Random Oracles" - (https://eprint.iacr.org/2026/016) - -The security analysis follows the framework of [DKKW25c] Section 6. Theorem 1 -gives an advantage bound as the sum of five terms. Each term divided by attacker -running time must be at most 2^{-(k + log5)}, yielding four independent -constraints (Parameter Requirements 2 and 3): - -1. Digest (SM-UD/SM-PRE via Eq 8-9 / Eq 15) -2. Public parameter (SM-TCR via Eq 6-7 / Eq 16) -3. Message hash (SM-rTCR via Eq 10 / Eq 13) -4. Randomness (SM-rTCR via Eq 10 / Eq 14) - -The abort correction from [HKKTW26] Corollary 1 and Remark 14 adjusts the -message hash bound: the aborting decode effectively enlarges the output space -to |H|/(1 - theta), where theta is the abort probability. -""" - -import math - -import pytest - -from lean_spec.subspecs.koalabear import P -from lean_spec.subspecs.xmss.constants import PROD_CONFIG, XmssConfig - - -def _calculate_layer_size(w: int, v: int, d: int) -> int: - """ - Calculates a hypercube layer's size using inclusion-exclusion. - - Counts integer solutions to x_1 + ... + x_v = k with 0 <= x_i <= w-1, - where k = v*(w-1) - d. - """ - coord_sum = v * (w - 1) - d - return sum( - ((-1) ** s) * math.comb(v, s) * math.comb(coord_sum - s * w + v - 1, v - 1) - for s in range(coord_sum // w + 1) - ) - - -def _compute_security_levels(config: XmssConfig) -> dict[str, float]: - """ - Computes classical and quantum security levels for an XMSS configuration. - - Returns a dict with keys: - - - k_classical: effective classical security (bits) - - k_quantum: effective quantum security (bits) - - expected_attempts: expected signing attempts per message - - signing_failure_log2: log2 of probability that all MAX_TRIES attempts fail - """ - v = config.DIMENSION - w_bits = int(math.log2(config.BASE)) - base = config.BASE - - # Bit sizes of the parameter spaces. - # - # Each KoalaBear field element contributes floor(log2(P)) = 31 bits. - fe_bits = 31 - bits_digest = config.HASH_LEN_FE * fe_bits - bits_param = config.PARAMETER_LEN * fe_bits - bits_rand = config.RAND_LEN_FE * fe_bits - - # Raw message hash output: v chunks of w bits each. - bits_msg = v * w_bits - - # Abort correction from [HKKTW26] Corollary 1, Remark 14. - # - # Each field element aborts iff A_i >= Q * BASE^Z (i.e., A_i == P - 1). - # The non-abort probability per FE is (Q * BASE^Z) / P = (P - 1) / P. - # Over ell = ceil(v / Z) field elements, the total non-abort probability is: - # (1 - theta) = ((P - 1) / P) ^ ell - # - # The aborting rTCR bound ([HKKTW26] Corollary 1) gains a factor (1 - theta), - # which is equivalent to hashing into a space of size |H| / (1 - theta). - # This adds -log2(1 - theta) bits to the effective message hash output. - wz = base**config.Z - q = config.Q - ell = math.ceil(v / config.Z) - - non_abort_total = ((q * wz) / P) ** ell - abort_correction_bits = -math.log2(non_abort_total) - - bits_msg_eff = bits_msg + abort_correction_bits - - # Useful logarithmic constants. - log5 = math.log2(5) - log12 = math.log2(12) - log3 = math.log2(3) - log_lifetime = math.log2(config.LIFETIME) - logv = math.log2(v) - log_max_tries = math.log2(config.MAX_TRIES) - logqs = math.log2(config.LIFETIME) - - # Classical security: minimum over four attack surfaces. - # - # Each bound derives from the requirement that each of the five terms in - # Theorem 1 satisfies Adv_i / T(A) <= 2^{-(k_C + log5)}. - k_classical = min( - # [DKKW25c] Eq (15): SM-UD + SM-PRE on the digest hash Th. - bits_digest - log5 - 2 * w_bits - log_lifetime - logv, - # [DKKW25c] Eq (16): SM-TCR on the public parameter space. - bits_param - log5 - 3, - # [DKKW25c] Eq (13) + [HKKTW26] Corollary 1: SM-rTCR on message hash. - bits_msg_eff - log5 - 1, - # [DKKW25c] Eq (14): SM-rTCR randomness reprogramming. - bits_rand - log5 - logqs - log_max_tries - 1, - ) - - # Quantum security: minimum over four attack surfaces. - # - # Uses quantum ROM bounds from [DKKW25c] Table 1. - k_quantum = min( - # [DKKW25c] Eq (15), quantum: digest hash. - bits_digest / 2 - log5 - 2 * w_bits - log_lifetime - logv - log12, - # [DKKW25c] Eq (16), quantum: public parameter. - (bits_param - 5) / 2 - log5 - 2, - # [DKKW25c] Eq (13) + [HKKTW26] Corollary 1, quantum: message hash. - (bits_msg_eff - 3) / 2 - log5 - 1, - # [DKKW25c] Eq (14), quantum: randomness reprogramming. - (bits_rand - logqs) / 2 - log5 - log3 - log_max_tries, - ) - - # Expected signing attempts for target-sum encoding. - # - # [DKKW25c] Construction 6, Lemma 7: the number of valid codewords is - # |C| = #{x in Z_W^v : sum(x_i) = T}, the layer size at distance - # d = v*(W-1) - T from the sink vertex. The inclusion-exclusion formula - # from _calculate_layer_size gives |C|. - # - # Success probability per attempt = P(no abort) * P(target layer | no abort). - d = v * (base - 1) - config.TARGET_SUM - layer_size = _calculate_layer_size(base, v, d) - layer_prob = layer_size / base**v - success_prob = non_abort_total * layer_prob - expected_attempts = 1 / success_prob - - # [DKKW25c] Lemma 3: correctness error is delta^K where delta = 1 - success_prob. - signing_failure_log2 = config.MAX_TRIES * math.log2(1 - success_prob) - - return { - "k_classical": k_classical, - "k_quantum": k_quantum, - "expected_attempts": expected_attempts, - "signing_failure_log2": signing_failure_log2, - } - - -def test_prod_classical_security() -> None: - """Production parameters must achieve at least 128-bit classical security.""" - levels = _compute_security_levels(PROD_CONFIG) - assert levels["k_classical"] >= 128, ( - f"Classical security {levels['k_classical']:.2f} bits is below 128" - ) - - -def test_prod_quantum_security() -> None: - """Production parameters must achieve at least 64-bit quantum security.""" - levels = _compute_security_levels(PROD_CONFIG) - assert levels["k_quantum"] >= 64, f"Quantum security {levels['k_quantum']:.2f} bits is below 64" - - -def test_prod_signing_efficiency() -> None: - """Signing must succeed within a reasonable number of attempts on average.""" - levels = _compute_security_levels(PROD_CONFIG) - - # Expected attempts should be manageable (< 1000). - assert levels["expected_attempts"] < 1000, ( - f"Expected {levels['expected_attempts']:.2f} signing attempts is too high" - ) - - # The probability of MAX_TRIES consecutive failures must be astronomically small. - # log2(failure_prob) < -128 means failure probability < 2^{-128}. - assert levels["signing_failure_log2"] < -128, ( - f"Signing failure probability 2^{levels['signing_failure_log2']:.2f} is too high" - ) - - -def test_prod_abort_probability_is_negligible() -> None: - """ - The aborting decode rejection probability must be negligible. - - From [HKKTW26] Section 6.1: each FE has abort probability 1/P. - Over ceil(v/Z) FEs, the total abort probability is approximately - ceil(v/Z) / P. - """ - config = PROD_CONFIG - ell = math.ceil(config.DIMENSION / config.Z) - - # Per-FE non-abort probability: (Q * BASE^Z) / P = (P - 1) / P. - non_abort_per_fe = (config.Q * config.BASE**config.Z) / P - total_non_abort = non_abort_per_fe**ell - - # The abort probability should be less than 2^{-28} (~3.7e-9). - abort_prob = 1 - total_non_abort - assert abort_prob < 2**-28, f"Abort probability {abort_prob:.2e} is not negligible" - - -def test_prod_decomposition_invariant() -> None: - """ - Validates the fundamental relationship Q * BASE^Z == P - 1. - - From [HKKTW26] Section 6.1: for KoalaBear, P - 1 = 2^24 * 127. - With BASE = 2^w, the decomposition requires w | 24 so that - Z = 24 / w digits can be extracted from each field element. - """ - config = PROD_CONFIG - - # Core decomposition invariant (also checked at config construction time). - assert config.Q * config.BASE**config.Z == P - 1 - - # w must divide 24 for the rejection sampling to work with KoalaBear. - # - # P - 1 = 2^24 * 127, and BASE = 2^w, so we need w | 24. - w_bits = int(math.log2(config.BASE)) - assert config.BASE == 2**w_bits, "BASE must be a power of 2" - assert 24 % w_bits == 0, f"w={w_bits} must divide 24" - - # Z must equal 24 / w for the optimal decomposition (alpha = 1). - assert config.Z == 24 // w_bits, f"Z={config.Z} must equal 24/w={24 // w_bits}" - - -def test_prod_mh_hash_len_is_consistent() -> None: - """ - The Poseidon output length must produce enough digits to cover DIMENSION. - - From [HKKTW26] Section 6.1: ell = ceil(v / z) field elements produce - ell * z >= v base-w digits. - """ - config = PROD_CONFIG - assert config.MH_HASH_LEN_FE * config.Z >= config.DIMENSION - - -def test_prod_binding_constraint_is_message_hash() -> None: - """ - Verify the binding (smallest) constraint is the message hash for both - classical and quantum security, matching the design intent from [DKKW25c]. - """ - config = PROD_CONFIG - v = config.DIMENSION - w_bits = int(math.log2(config.BASE)) - fe_bits = 31 - - bits_digest = config.HASH_LEN_FE * fe_bits - bits_param = config.PARAMETER_LEN * fe_bits - bits_rand = config.RAND_LEN_FE * fe_bits - bits_msg = v * w_bits - - log5 = math.log2(5) - log_lifetime = math.log2(config.LIFETIME) - logv = math.log2(v) - log_max_tries = math.log2(config.MAX_TRIES) - - # Classical: the message hash bound v*w - log5 - 1 should be the tightest. - classical_bounds = [ - bits_digest - log5 - 2 * w_bits - log_lifetime - logv, - bits_param - log5 - 3, - bits_msg - log5 - 1, - bits_rand - log5 - log_lifetime - log_max_tries - 1, - ] - assert classical_bounds.index(min(classical_bounds)) == 2, ( - "Classical binding constraint should be message hash (index 2)" - ) - - -@pytest.mark.parametrize( - "param_name, value", - [ - ("DIMENSION", 46), - ("BASE", 8), - ("Z", 8), - ("Q", 127), - ("TARGET_SUM", 200), - ("LOG_LIFETIME", 32), - ("PARAMETER_LEN", 5), - ("TWEAK_LEN_FE", 2), - ("MSG_LEN_FE", 9), - ("RAND_LEN_FE", 7), - ("HASH_LEN_FE", 8), - ("CAPACITY", 9), - ], -) -def test_prod_config_matches_reference(param_name: str, value: int) -> None: - """ - Guards against accidental parameter drift. - - These values must match the canonical Rust implementation (leanSig). - """ - assert getattr(PROD_CONFIG, param_name) == value - - -def test_print_security_summary(capsys: pytest.CaptureFixture[str]) -> None: - """Prints a human-readable summary of the security analysis (informational).""" - levels = _compute_security_levels(PROD_CONFIG) - print("\n--- XMSS Production Security Summary ---") - print(f"Classical security: {levels['k_classical']:.2f} bits") - print(f"Quantum security: {levels['k_quantum']:.2f} bits") - print(f"Expected sign attempts: {levels['expected_attempts']:.2f}") - print(f"Signing failure (log2): {levels['signing_failure_log2']:.2f}") - print("----------------------------------------") diff --git a/tests/lean_spec/subspecs/xmss/test_ssz_serialization.py b/tests/lean_spec/subspecs/xmss/test_ssz_serialization.py deleted file mode 100644 index 02ce9b3e9..000000000 --- a/tests/lean_spec/subspecs/xmss/test_ssz_serialization.py +++ /dev/null @@ -1,147 +0,0 @@ -"""Tests for SSZ serialization of XMSS types.""" - -from lean_spec.subspecs.koalabear.field import P_BYTES -from lean_spec.subspecs.xmss.constants import TEST_CONFIG -from lean_spec.subspecs.xmss.containers import PublicKey, SecretKey, Signature -from lean_spec.subspecs.xmss.interface import TEST_SIGNATURE_SCHEME -from lean_spec.types import Bytes32, Slot, Uint64 - - -def test_public_key_ssz_roundtrip() -> None: - """Test that PublicKey can be SSZ serialized and deserialized.""" - # Generate a key pair - activation_slot = Slot(0) - num_active_slots = Uint64(32) - public_key = TEST_SIGNATURE_SCHEME.key_gen(activation_slot, num_active_slots).public_key - - # Serialize to bytes using SSZ - pk_bytes = public_key.encode_bytes() - - # Deserialize from bytes - recovered_pk = PublicKey.decode_bytes(pk_bytes) - - # Verify the recovered public key matches the original - assert recovered_pk.root == public_key.root - assert recovered_pk.parameter == public_key.parameter - assert recovered_pk == public_key - - -def test_signature_ssz_roundtrip() -> None: - """Test that Signature can be SSZ serialized and deserialized.""" - # Generate a key pair and sign a message - activation_slot = Slot(0) - num_active_slots = Uint64(32) - kp = TEST_SIGNATURE_SCHEME.key_gen(activation_slot, num_active_slots) - public_key, secret_key = kp.public_key, kp.secret_key - - message = Bytes32(bytes([42] * 32)) - epoch = Slot(0) - signature = TEST_SIGNATURE_SCHEME.sign(secret_key, epoch, message) - - # Serialize to bytes using SSZ - sig_bytes = signature.encode_bytes() - - # Deserialize from bytes - recovered_sig = Signature.decode_bytes(sig_bytes) - - # Verify the recovered signature matches the original - assert recovered_sig.path.siblings == signature.path.siblings - assert recovered_sig.rho == signature.rho - assert recovered_sig.hashes == signature.hashes - assert recovered_sig == signature - - # Verify the signature still verifies - assert TEST_SIGNATURE_SCHEME.verify(public_key, epoch, message, recovered_sig) - - -def test_secret_key_ssz_roundtrip() -> None: - """Test that SecretKey can be SSZ serialized and deserialized.""" - # Generate a key pair - activation_slot = Slot(0) - num_active_slots = Uint64(32) - kp = TEST_SIGNATURE_SCHEME.key_gen(activation_slot, num_active_slots) - public_key, secret_key = kp.public_key, kp.secret_key - - # Serialize to bytes using SSZ - sk_bytes = secret_key.encode_bytes() - - # Deserialize from bytes - recovered_sk = SecretKey.decode_bytes(sk_bytes) - - # Verify the recovered secret key matches the original - assert recovered_sk.prf_key == secret_key.prf_key - assert recovered_sk.parameter == secret_key.parameter - assert recovered_sk.activation_slot == secret_key.activation_slot - assert recovered_sk.num_active_slots == secret_key.num_active_slots - assert recovered_sk.top_tree == secret_key.top_tree - assert recovered_sk.left_bottom_tree_index == secret_key.left_bottom_tree_index - assert recovered_sk.left_bottom_tree == secret_key.left_bottom_tree - assert recovered_sk.right_bottom_tree == secret_key.right_bottom_tree - assert recovered_sk == secret_key - - # Verify the recovered secret key can still sign - message = Bytes32(bytes([99] * 32)) - epoch = Slot(1) - signature = TEST_SIGNATURE_SCHEME.sign(recovered_sk, epoch, message) - assert TEST_SIGNATURE_SCHEME.verify(public_key, epoch, message, signature) - - -def test_deterministic_serialization() -> None: - """Test that serialization is deterministic.""" - # Generate a key pair - activation_slot = Slot(0) - num_active_slots = Uint64(32) - kp = TEST_SIGNATURE_SCHEME.key_gen(activation_slot, num_active_slots) - public_key, secret_key = kp.public_key, kp.secret_key - - # Serialize multiple times - pk_bytes1 = public_key.encode_bytes() - pk_bytes2 = public_key.encode_bytes() - sk_bytes1 = secret_key.encode_bytes() - sk_bytes2 = secret_key.encode_bytes() - - # Verify serialization is deterministic - assert pk_bytes1 == pk_bytes2 - assert sk_bytes1 == sk_bytes2 - - # Sign a message multiple times with deterministic randomness - message = Bytes32(bytes([42] * 32)) - epoch = Slot(0) - sig1 = TEST_SIGNATURE_SCHEME.sign(secret_key, epoch, message) - sig2 = TEST_SIGNATURE_SCHEME.sign(secret_key, epoch, message) - - # Signatures should be identical (deterministic signing) - assert sig1 == sig2 - - sig_bytes1 = sig1.encode_bytes() - sig_bytes2 = sig2.encode_bytes() - assert sig_bytes1 == sig_bytes2 - - -def test_signature_size_matches_config() -> None: - """Verify SIGNATURE_LEN_BYTES matches actual SSZ-encoded size.""" - activation_slot = Slot(0) - num_active_slots = Uint64(32) - secret_key = TEST_SIGNATURE_SCHEME.key_gen(activation_slot, num_active_slots).secret_key - - message = Bytes32(bytes([42] * 32)) - epoch = Slot(0) - signature = TEST_SIGNATURE_SCHEME.sign(secret_key, epoch, message) - - encoded = signature.encode_bytes() - assert len(encoded) == TEST_CONFIG.SIGNATURE_LEN_BYTES - - -def test_public_key_size_matches_config() -> None: - """Verify the encoded public key size matches its SSZ shape. - - A public key is a HASH_LEN_FE-element hash digest plus a PARAMETER_LEN - public parameter, each field element packed into P_BYTES bytes. - """ - activation_slot = Slot(0) - num_active_slots = Uint64(32) - public_key = TEST_SIGNATURE_SCHEME.key_gen(activation_slot, num_active_slots).public_key - - encoded = public_key.encode_bytes() - expected = (TEST_CONFIG.HASH_LEN_FE + TEST_CONFIG.PARAMETER_LEN) * P_BYTES - assert len(encoded) == expected diff --git a/tests/lean_spec/subspecs/xmss/test_types.py b/tests/lean_spec/subspecs/xmss/test_types.py new file mode 100644 index 000000000..36f78c90f --- /dev/null +++ b/tests/lean_spec/subspecs/xmss/test_types.py @@ -0,0 +1,92 @@ +"""Tests for the base SSZ types of the XMSS signature scheme.""" + +import pytest + +from lean_spec.subspecs.koalabear import Fp +from lean_spec.subspecs.xmss.constants import TEST_CONFIG +from lean_spec.subspecs.xmss.field import random_domain +from lean_spec.subspecs.xmss.types import ( + NODE_LIST_LIMIT, + ChainTweak, + HashDigestList, + HashDigestVector, + HashTreeOpening, + Parameter, + Randomness, + TreeTweak, +) +from lean_spec.types import Uint64 +from lean_spec.types.exceptions import SSZValueError + + +def test_tree_tweak_fields() -> None: + """A tree tweak stores its level and node index in order.""" + assert TreeTweak(level=3, index=Uint64(7)) == (3, Uint64(7)) + + +def test_chain_tweak_fields() -> None: + """A chain tweak stores its epoch, chain index, and step in order.""" + assert ChainTweak(epoch=Uint64(2), chain_index=5, step=1) == (Uint64(2), 5, 1) + + +def test_node_list_limit_is_twice_the_leaf_row() -> None: + """The sparse-layer cap is twice the widest bottom-tree leaf row.""" + assert NODE_LIST_LIMIT == 2 * TEST_CONFIG.LEAVES_PER_BOTTOM_TREE + + +def test_hash_digest_vector_length_is_digest_length() -> None: + """A digest vector holds exactly one Poseidon output worth of elements.""" + assert HashDigestVector.LENGTH == TEST_CONFIG.HASH_LEN_FE + + +def test_hash_digest_vector_accepts_exact_length() -> None: + """A digest vector of the configured length validates.""" + data = [Fp(value=i) for i in range(TEST_CONFIG.HASH_LEN_FE)] + assert HashDigestVector(data=data).data == tuple(data) + + +def test_hash_digest_vector_rejects_wrong_length() -> None: + """A digest vector of the wrong length fails validation.""" + with pytest.raises(SSZValueError): + HashDigestVector(data=[Fp(value=0)] * (TEST_CONFIG.HASH_LEN_FE + 1)) + + +def test_parameter_length_is_parameter_length() -> None: + """A parameter holds the configured number of personalization elements.""" + assert Parameter.LENGTH == TEST_CONFIG.PARAMETER_LEN + + +def test_randomness_length_is_randomness_length() -> None: + """The signing randomness holds the configured number of elements.""" + assert Randomness.LENGTH == TEST_CONFIG.RAND_LEN_FE + + +def test_hash_digest_list_limit_is_node_list_limit() -> None: + """The digest list cap matches the sparse-layer node limit.""" + assert HashDigestList.LIMIT == NODE_LIST_LIMIT + + +def test_hash_digest_list_accepts_limit_entries() -> None: + """A digest list filled to the cap validates.""" + nodes = [random_domain(TEST_CONFIG) for _ in range(NODE_LIST_LIMIT)] + assert len(HashDigestList(data=nodes).data) == NODE_LIST_LIMIT + + +def test_hash_digest_list_rejects_over_limit() -> None: + """A digest list one entry past the cap fails validation.""" + nodes = [random_domain(TEST_CONFIG) for _ in range(NODE_LIST_LIMIT + 1)] + with pytest.raises(SSZValueError): + HashDigestList(data=nodes) + + +def test_hash_tree_opening_roundtrips_through_ssz() -> None: + """An opening encodes and decodes back to an equal value.""" + siblings = [random_domain(TEST_CONFIG) for _ in range(3)] + opening = HashTreeOpening(siblings=HashDigestList(data=siblings)) + assert HashTreeOpening.decode_bytes(opening.encode_bytes()) == opening + + +def test_hash_tree_opening_empty_is_allowed() -> None: + """An opening with no siblings is a valid empty path.""" + opening = HashTreeOpening(siblings=HashDigestList(data=[])) + assert len(opening.siblings) == 0 diff --git a/tests/lean_spec/subspecs/xmss/test_utils.py b/tests/lean_spec/subspecs/xmss/test_utils.py deleted file mode 100644 index 3be470bf1..000000000 --- a/tests/lean_spec/subspecs/xmss/test_utils.py +++ /dev/null @@ -1,200 +0,0 @@ -"""Tests for the utility functions in the XMSS signature scheme.""" - -import secrets -from typing import List - -import pytest - -from lean_spec.subspecs.koalabear.field import Fp, P -from lean_spec.subspecs.xmss.constants import TEST_CONFIG -from lean_spec.subspecs.xmss.field import int_to_base_p -from lean_spec.subspecs.xmss.interface import _expand_activation_time -from lean_spec.subspecs.xmss.merkle import HashSubTree -from lean_spec.subspecs.xmss.poseidon import TEST_POSEIDON -from lean_spec.subspecs.xmss.prf import PRFKey -from lean_spec.subspecs.xmss.types import Parameter -from lean_spec.types import Uint64 - - -@pytest.mark.parametrize( - "value, num_limbs, expected_values", - [ - (0, 4, [0, 0, 0, 0]), - (123, 4, [123, 0, 0, 0]), - (P, 4, [0, 1, 0, 0]), - (P - 1, 4, [P - 1, 0, 0, 0]), - (3 * (P**2) + 2 * P + 1, 4, [1, 2, 3, 0]), - (P**3 - 1, 3, [P - 1, P - 1, P - 1]), - ], -) -def test_int_to_base_p(value: int, num_limbs: int, expected_values: List[int]) -> None: - """Validates the base-P decomposition of an integer with known-answer tests.""" - # Convert the list of expected integer values to a list of Fp objects for comparison. - expected_limbs = [Fp(value=v) for v in expected_values] - # Perform the decomposition. - actual_limbs = int_to_base_p(value, num_limbs) - # Assert that the result matches the expected output. - assert actual_limbs == expected_limbs - - -def test_int_to_base_p_roundtrip() -> None: - """Ensures that the base-P decomposition is perfectly reversible.""" - # Create a large, random multi-limb integer. - num_limbs = 5 - original_limbs = [secrets.randbelow(P) for _ in range(num_limbs)] - original_value = sum(val * (P**i) for i, val in enumerate(original_limbs)) - - # Decompose the integer into base-P limbs using the function under test. - decomposed_limbs_fp = int_to_base_p(original_value, num_limbs) - decomposed_limbs = [int(fp) for fp in decomposed_limbs_fp] - - # Reconstruct the integer from the decomposed limbs. - reconstructed_value = sum(val * (P**i) for i, val in enumerate(decomposed_limbs)) - - # Assert that the original and reconstructed values are identical. - assert original_value == reconstructed_value - # Also assert that the original and decomposed limbs match. - assert original_limbs == decomposed_limbs - - -@pytest.mark.parametrize( - "log_lifetime, desired_activation, desired_num, expected_start_tree, expected_end_tree", - [ - # Test case 1: Request falls on boundary, minimum duration - (8, 0, 16, 0, 2), # C = 16, requested [0, 16), aligned [0, 32) = 2 trees - # Test case 2: Request needs rounding - (8, 10, 5, 0, 2), # C = 16, requested [10, 15), aligned [0, 32) = 2 trees - # Test case 3: Larger request - (8, 0, 100, 0, 7), # C = 16, requested [0, 100), aligned [0, 112) = 7 trees - # Test case 4: Request that exceeds lifetime - (4, 0, 300, 0, 4), # C = 4, LIFETIME = 16, clamped to [0, 16) = 4 trees - # Test case 5: Request in middle - (8, 32, 16, 2, 4), # C = 16, requested [32, 48), aligned [32, 48) = 2 trees - ], -) -def test_expand_activation_time( - log_lifetime: int, - desired_activation: int, - desired_num: int, - expected_start_tree: int, - expected_end_tree: int, -) -> None: - """Tests that _expand_activation_time correctly aligns and expands activation intervals.""" - start_tree, end_tree = _expand_activation_time(log_lifetime, desired_activation, desired_num) - assert start_tree == expected_start_tree - assert end_tree == expected_end_tree - - # Verify minimum duration constraint (at least 2 bottom trees) - assert end_tree - start_tree >= 2 - - # Verify alignment - c = 1 << (log_lifetime // 2) - actual_start_slot = start_tree * c - actual_end_slot = end_tree * c - assert actual_start_slot % c == 0 - assert actual_end_slot % c == 0 - - # Verify it covers the desired range (if the desired range fits within lifetime) - lifetime = c * c - desired_end_slot = desired_activation + desired_num - if desired_end_slot <= lifetime: - assert actual_start_slot <= desired_activation - assert actual_end_slot >= desired_end_slot - else: - # If desired range exceeds lifetime, verify it's clamped to lifetime bounds - assert actual_start_slot >= 0 - assert actual_end_slot <= lifetime - - -def test_hash_subtree_from_prf_key() -> None: - """Tests that HashSubTree.from_prf_key generates a valid bottom tree.""" - config = TEST_CONFIG - - # Generate a PRF key - prf_key = PRFKey.generate() - - # Generate a random parameter - parameter = Parameter( - data=[Fp(value=secrets.randbelow(P)) for _ in range(config.PARAMETER_LEN)] - ) - - # Generate bottom tree 0 - bottom_tree = HashSubTree.from_prf_key( - poseidon=TEST_POSEIDON, - config=config, - prf_key=prf_key, - bottom_tree_index=Uint64(0), - parameter=parameter, - ) - - # Verify structure - assert bottom_tree.depth == Uint64(config.LOG_LIFETIME) - assert bottom_tree.lowest_layer == Uint64(0) - assert len(bottom_tree.layers) > 0 - - # Verify the root layer has exactly one node - root_layer = bottom_tree.layers.data[-1] - assert len(root_layer.nodes) == 1 - - # Verify the leaf layer covers the right range - leafs_per_bottom_tree = 1 << (config.LOG_LIFETIME // 2) - leaf_layer = bottom_tree.layers.data[0] - assert len(leaf_layer.nodes) == leafs_per_bottom_tree - - -def test_hash_subtree_from_prf_key_deterministic() -> None: - """Tests that HashSubTree.from_prf_key is deterministic.""" - config = TEST_CONFIG - prf_key = PRFKey.generate() - parameter = Parameter( - data=[Fp(value=secrets.randbelow(P)) for _ in range(config.PARAMETER_LEN)] - ) - - # Generate the same bottom tree twice - tree1 = HashSubTree.from_prf_key( - poseidon=TEST_POSEIDON, - config=config, - prf_key=prf_key, - bottom_tree_index=Uint64(0), - parameter=parameter, - ) - - tree2 = HashSubTree.from_prf_key( - poseidon=TEST_POSEIDON, - config=config, - prf_key=prf_key, - bottom_tree_index=Uint64(0), - parameter=parameter, - ) - - # Verify the roots are identical - assert tree1.layers.data[-1].nodes[0] == tree2.layers.data[-1].nodes[0] - - -def test_hash_subtree_from_prf_key_different_indices() -> None: - """Tests that different bottom tree indices produce different trees.""" - config = TEST_CONFIG - prf_key = PRFKey.generate() - parameter = Parameter( - data=[Fp(value=secrets.randbelow(P)) for _ in range(config.PARAMETER_LEN)] - ) - - # Generate two different bottom trees - tree0 = HashSubTree.from_prf_key( - poseidon=TEST_POSEIDON, - config=config, - prf_key=prf_key, - bottom_tree_index=Uint64(0), - parameter=parameter, - ) - - tree1 = HashSubTree.from_prf_key( - poseidon=TEST_POSEIDON, - config=config, - prf_key=prf_key, - bottom_tree_index=Uint64(1), - parameter=parameter, - ) - - # Verify the roots are different - assert tree0.layers.data[-1].nodes[0] != tree1.layers.data[-1].nodes[0] From 1d5313c5f2015fcffeb14ba0f8962ed93e066651 Mon Sep 17 00:00:00 2001 From: Thomas Coratger <60488569+tcoratger@users.noreply.github.com> Date: Wed, 27 May 2026 21:56:53 +0200 Subject: [PATCH 9/9] test(xmss): make the xmss test directory a package The new xmss test module mirroring the field source shares a basename with the koalabear field test. Without a package marker, pytest's prepend import mode maps both to the same module name and aborts collection. Adding an empty package init namespaces the xmss test modules and resolves the clash, matching how other test directories in the tree are already packaged. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/lean_spec/subspecs/xmss/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/lean_spec/subspecs/xmss/__init__.py diff --git a/tests/lean_spec/subspecs/xmss/__init__.py b/tests/lean_spec/subspecs/xmss/__init__.py new file mode 100644 index 000000000..e69de29bb