From a7d758f45c518a6590a7c7dd2f71330d55ef5432 Mon Sep 17 00:00:00 2001 From: defiant1708 Date: Tue, 19 Aug 2025 23:03:47 +0900 Subject: [PATCH 1/5] feat(certificates): add authhttp client, certificate utils and crypto primitives (AES-CBC/GCM, legacy helpers) --- bsv/aes_gcm.py | 61 +++++++ bsv/auth/clients/authhttp.py | 295 +++++++++++++++++++++++++++++++++ bsv/auth/utils.py | 201 +++++++++++++++++++++++ bsv/primitives/aescbc.py | 42 +++++ bsv/utils/legacy.py | 306 +++++++++++++++++++++++++++++++++++ 5 files changed, 905 insertions(+) create mode 100644 bsv/aes_gcm.py create mode 100644 bsv/auth/clients/authhttp.py create mode 100644 bsv/auth/utils.py create mode 100644 bsv/primitives/aescbc.py create mode 100644 bsv/utils/legacy.py diff --git a/bsv/aes_gcm.py b/bsv/aes_gcm.py new file mode 100644 index 0000000..9f0b9ce --- /dev/null +++ b/bsv/aes_gcm.py @@ -0,0 +1,61 @@ +from Cryptodome.Cipher import AES +from Cryptodome.Util import Padding + +class AESGCMError(Exception): + pass + +def aes_gcm_encrypt(plaintext: bytes, key: bytes, iv: bytes, aad: bytes = b""): + cipher = AES.new(key, AES.MODE_GCM, nonce=iv) + cipher.update(aad) + ciphertext, tag = cipher.encrypt_and_digest(plaintext) + return ciphertext, tag + +def aes_gcm_decrypt(ciphertext: bytes, key: bytes, iv: bytes, tag: bytes, aad: bytes = b""): + cipher = AES.new(key, AES.MODE_GCM, nonce=iv) + cipher.update(aad) + try: + plaintext = cipher.decrypt_and_verify(ciphertext, tag) + return plaintext + except ValueError as e: + raise AESGCMError(f"decryption failed: {e}") + +# --- GHASH utilities (for test vector compatibility, optional) --- +def xor_bytes(a: bytes, b: bytes) -> bytes: + return bytes(x ^ y for x, y in zip(a, b)) + +def right_shift(block: bytes) -> bytes: + b = bytearray(block) + carry = 0 + for i in range(len(b)): + old_carry = carry + carry = b[i] & 0x01 + b[i] >>= 1 + if old_carry: + b[i] |= 0x80 + return bytes(b) + +def check_bit(block: bytes, index: int, bit: int) -> bool: + return ((block[index] >> bit) & 1) == 1 + +def multiply(block0: bytes, block1: bytes) -> bytes: + v = bytearray(block1) + z = bytearray(16) + r = bytearray([0xe1] + [0x00]*15) + for i in range(16): + for j in range(7, -1, -1): + if check_bit(block0, i, j): + z = bytearray(x ^ y for x, y in zip(z, v)) + if check_bit(v, 15, 0): + v = bytearray(x ^ y for x, y in zip(right_shift(v), r)) + else: + v = bytearray(right_shift(v)) + return bytes(z) + +def ghash(input_bytes: bytes, hash_subkey: bytes) -> bytes: + result = bytes(16) + for i in range(0, len(input_bytes), 16): + block = input_bytes[i:i+16] + if len(block) < 16: + block = block + b"\x00" * (16 - len(block)) + result = multiply(xor_bytes(result, block), hash_subkey) + return result diff --git a/bsv/auth/clients/authhttp.py b/bsv/auth/clients/authhttp.py new file mode 100644 index 0000000..eea3296 --- /dev/null +++ b/bsv/auth/clients/authhttp.py @@ -0,0 +1,295 @@ +import threading +from typing import Any, Callable, Dict, Optional, List +import logging +import base64 +import os +import time +import urllib.parse +import requests + +from ..auth.peer import Peer +from ..auth.session_manager import DefaultSessionManager +from ..auth.requested_certificate_set import RequestedCertificateSet +from ..auth.verifiable_certificate import VerifiableCertificate +from ..auth.transports.simplified_http_transport import SimplifiedHTTPTransport +# from ...wallet.WalletInterface import WalletInterface + +class SimplifiedFetchRequestOptions: + def __init__(self, method: str = "GET", headers: Optional[Dict[str, str]] = None, body: Optional[bytes] = None, retry_counter: Optional[int] = None): + self.method = method + self.headers = headers or {} + self.body = body + self.retry_counter = retry_counter + +class AuthPeer: + def __init__(self): + self.peer = None # type: Optional[Peer] + self.identity_key = "" + self.supports_mutual_auth = None # type: Optional[bool] + self.pending_certificate_requests: List[bool] = [] + +class AuthFetch: + def __init__(self, wallet, requested_certs, session_manager=None): + if session_manager is None: + session_manager = DefaultSessionManager() + self.session_manager = session_manager + self.wallet = wallet + self.callbacks = {} # type: Dict[str, Dict[str, Callable]] + self.certificates_received = [] # type: List[VerifiableCertificate] + self.requested_certificates = requested_certs + self.peers = {} # type: Dict[str, AuthPeer] + self.logger = logging.getLogger("AuthHTTP") + + def fetch(self, ctx: Any, url_str: str, config: Optional[SimplifiedFetchRequestOptions] = None): + if config is None: + config = SimplifiedFetchRequestOptions() + # Handle retry counter + if config.retry_counter is not None: + if config.retry_counter <= 0: + raise Exception("request failed after maximum number of retries") + config.retry_counter -= 1 + # Extract base URL + parsed_url = urllib.parse.urlparse(url_str) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + # Create peer if needed + if base_url not in self.peers: + transport = SimplifiedHTTPTransport(base_url) + peer = Peer( + wallet=self.wallet, + transport=transport, + certificates_to_request=self.requested_certificates, + session_manager=self.session_manager + ) + auth_peer = AuthPeer() + auth_peer.peer = peer + self.peers[base_url] = auth_peer + # Set up certificate received/requested listeners(省略: 必要に応じて追加) + peer_to_use = self.peers[base_url] + # Generate request nonce + request_nonce = os.urandom(32) + request_nonce_b64 = base64.b64encode(request_nonce).decode() + # Serialize request + request_data = self.serialize_request( + config.method, + config.headers, + config.body or b"", + parsed_url, + request_nonce + ) + # コールバック用イベントと結果格納 + response_event = threading.Event() + response_holder = {'resp': None, 'err': None} + # コールバック登録 + self.callbacks[request_nonce_b64] = { + 'resolve': lambda resp: (response_holder.update({'resp': resp}), response_event.set()), + 'reject': lambda err: (response_holder.update({'err': err}), response_event.set()), + } + # Peerのgeneral messageリスナー登録 + def on_general_message(sender_public_key, payload): + # 先頭32バイトがresponse_nonce + if not payload or len(payload) < 32: + return + response_nonce = payload[:32] + response_nonce_b64 = base64.b64encode(response_nonce).decode() + if response_nonce_b64 != request_nonce_b64: + return # 自分のリクエストでなければ無視 + # 以降はHTTPレスポンスのデシリアライズ等(省略: 必要に応じて実装) + self.callbacks[request_nonce_b64]['resolve'](payload) + listener_id = peer_to_use.peer.listen_for_general_messages(on_general_message) + try: + # Peer経由で送信(ToPeer相当) + err = peer_to_use.peer.to_peer(ctx, request_data, None, 30000) + if err: + self.callbacks[request_nonce_b64]['reject'](err) + except Exception as e: + self.callbacks[request_nonce_b64]['reject'](e) + # レスポンス待機(またはタイムアウト) + response_event.wait(timeout=30) # 30秒タイムアウト + # コールバック解除 + peer_to_use.peer.stop_listening_for_general_messages(listener_id) + self.callbacks.pop(request_nonce_b64, None) + # 結果返却 + if response_holder['err']: + raise Exception(response_holder['err']) + return response_holder['resp'] + + def send_certificate_request(self, ctx: Any, base_url: str, certificates_to_request): + """ + GoのSendCertificateRequest相当: Peer経由で証明書リクエストを送り、受信まで待機。 + """ + parsed_url = urllib.parse.urlparse(base_url) + base_url_str = f"{parsed_url.scheme}://{parsed_url.netloc}" + if base_url_str not in self.peers: + transport = SimplifiedHTTPTransport(base_url_str) + peer = Peer( + wallet=self.wallet, + transport=transport, + certificates_to_request=self.requested_certificates, + session_manager=self.session_manager + ) + auth_peer = AuthPeer() + auth_peer.peer = peer + self.peers[base_url_str] = auth_peer + peer_to_use = self.peers[base_url_str] + # コールバック用イベントと結果格納 + cert_event = threading.Event() + cert_holder = {'certs': None, 'err': None} + def on_certificates_received(sender_public_key, certs): + cert_holder['certs'] = certs + cert_event.set() + callback_id = peer_to_use.peer.listen_for_certificates_received(on_certificates_received) + try: + err = peer_to_use.peer.request_certificates(ctx, None, certificates_to_request, 30000) + if err: + cert_holder['err'] = err + cert_event.set() + except Exception as e: + cert_holder['err'] = e + cert_event.set() + cert_event.wait(timeout=30) + peer_to_use.peer.stop_listening_for_certificates_received(callback_id) + if cert_holder['err']: + raise Exception(cert_holder['err']) + return cert_holder['certs'] + + def consume_received_certificates(self): + certs = self.certificates_received + self.certificates_received = [] + return certs + + def serialize_request(self, method: str, headers: Dict[str, str], body: bytes, parsed_url, request_nonce: bytes): + """ + GoのserializeRequestメソッドをPythonで再現。 + - method, headers, body, parsed_url, request_nonceをバイナリで直列化 + - ヘッダーはx-bsv-*系やcontent-type, authorizationのみ含める + - Goのutil.NewWriter/WriteVarInt相当はbytearray+独自関数で実装 + """ + import struct + import math + from collections import OrderedDict + + def write_varint(writer: bytearray, value: int): + # Bitcoin style varint (for simplicity, 8byte unsigned) + writer += struct.pack(' bool: + """ + Verifies that a nonce was derived from the given wallet. + Ported from Go/TypeScript verifyNonce. + """ + try: + nonce_bytes = base64.b64decode(nonce) + except Exception: + return False + if len(nonce_bytes) <= 16: + return False + data = nonce_bytes[:16] + hmac = nonce_bytes[16:] + # Prepare encryption_args for wallet.verify_hmac + encryption_args = { + 'protocol_id': { + 'securityLevel': 1, # Go版: SecurityLevelEveryApp = 1 + 'protocol': 'server hmac' + }, + 'key_id': data.decode('latin1'), # Go版: string(randomBytes) + 'counterparty': counterparty + } + args = { + 'encryption_args': encryption_args, + 'data': data, + 'hmac': hmac + } + try: + result = wallet.verify_hmac(ctx, args, "") + print(f"[verify_nonce] result={result}") + if isinstance(result, dict): + return bool(result.get('valid', False)) + else: + return bool(getattr(result, 'valid', False)) + except Exception: + return False + +def create_nonce(wallet: Any, counterparty: Any = None, ctx: Any = None) -> str: + """ + Creates a nonce derived from a wallet (ported from TypeScript createNonce). + """ + # Generate 16 random bytes for the first half of the data + first_half = os.urandom(16) + # Create an sha256 HMAC + encryption_args = { + 'protocol_id': { + 'securityLevel': 1, # Go版: SecurityLevelEveryApp = 1 + 'protocol': 'server hmac' + }, + 'key_id': first_half.decode('latin1'), # Go版: string(randomBytes) + 'counterparty': counterparty + } + args = { + 'encryption_args': encryption_args, + 'data': first_half + } + result = wallet.create_hmac(ctx, args, "") + print(f"[create_nonce] result={result}") + hmac = result.get('hmac') if isinstance(result, dict) else getattr(result, 'hmac', None) + if hmac is None: + raise Exception('Failed to create HMAC for nonce') + nonce_bytes = first_half + hmac + return base64.b64encode(nonce_bytes).decode('ascii') + + +def get_verifiable_certificates(wallet, requested_certificates, verifier_identity_key): + """ + Retrieves an array of verifiable certificates based on the request (ported from TypeScript getVerifiableCertificates). + """ + # Find matching certificates we have + matching = wallet.list_certificates({ + 'certifiers': requested_certificates.get('certifiers', []), + 'types': list(requested_certificates.get('types', {}).keys()) + }) + certificates = matching.get('certificates', []) + result = [] + for certificate in certificates: + proof = wallet.prove_certificate({ + 'certificate': certificate, + 'fields_to_reveal': requested_certificates['types'].get(certificate['type'], []), + 'verifier': verifier_identity_key + }) + # Construct VerifiableCertificate (assume similar constructor as TS) + from bsv.auth.verifiable_certificate import VerifiableCertificate + verifiable = VerifiableCertificate( + certificate['type'], + certificate['serialNumber'], + certificate['subject'], + certificate['certifier'], + certificate['revocationOutpoint'], + certificate['fields'], + proof.get('keyring_for_verifier', {}), + certificate['signature'] + ) + result.append(verifiable) + return result + + +def validate_certificates(verifier_wallet, message, certificates_requested=None): + """ + Validates and processes certificates received from a peer. + - Ensures each certificate's subject equals message.identityKey + - Verifies signature + - If certificates_requested is provided, enforces certifier/type/required fields + - Attempts to decrypt fields using the verifier wallet + Raises Exception on validation failure. + """ + from bsv.auth.verifiable_certificate import VerifiableCertificate + + certificates = getattr(message, 'certificates', None) or (message.get('certificates', None) if isinstance(message, dict) else None) + identity_key = getattr(message, 'identityKey', None) or (message.get('identityKey', None) if isinstance(message, dict) else None) + if not certificates: + raise Exception('No certificates were provided in the AuthMessage.') + if identity_key is None: + raise Exception('identityKey must be provided in the AuthMessage.') + + # Normalize certificates_requested into (allowed_certifiers, requested_types_map) + def _normalize_requested(req): + allowed_certifiers = [] + requested_types = {} + if req is None: + return allowed_certifiers, requested_types + try: + # RequestedCertificateSet + from bsv.auth.requested_certificate_set import RequestedCertificateSet + if isinstance(req, RequestedCertificateSet): + allowed_certifiers = list(getattr(req, 'certifiers', []) or []) + # For utils we expect plain string type keys; convert bytes keys to base64 strings + mapping = getattr(getattr(req, 'certificate_types', None), 'mapping', {}) or {} + requested_types = {base64.b64encode(k).decode('ascii'): list(v or []) for k, v in mapping.items()} + return allowed_certifiers, requested_types + except Exception: + pass + # dict-like + if isinstance(req, dict): + allowed_certifiers = req.get('certifiers') or req.get('Certifiers') or [] + types_dict = req.get('certificate_types') or req.get('certificateTypes') or req.get('types') or {} + # In utils tests, type keys are simple strings. Keep as-is. + for k, v in types_dict.items(): + requested_types[str(k)] = list(v or []) + return allowed_certifiers, requested_types + + allowed_certifiers, requested_types = _normalize_requested(certificates_requested) + + for incoming in certificates: + # Extract fields as-is (tests expect plain strings, not decoded keys) + cert_type = incoming.get('type') + serial_number = incoming.get('serialNumber') or incoming.get('serial_number') + subject = incoming.get('subject') + certifier = incoming.get('certifier') + fields = incoming.get('fields') or {} + signature = incoming.get('signature') + keyring = incoming.get('keyring') or {} + + if subject != identity_key: + raise Exception(f'The subject of one of your certificates ("{subject}") is not the same as the request sender ("{identity_key}").') + + # Instantiate VerifiableCertificate with backwards-compatible signature used in tests + try: + vc = VerifiableCertificate(cert_type, serial_number, subject, certifier, incoming.get('revocationOutpoint'), fields, keyring, signature) + except Exception: + # Fallback: if real class is present, try wrapping via real constructor + try: + from bsv.auth.certificate import Certificate as _Cert, Outpoint as _Out + from bsv.keys import PublicKey as _PK + subj_pk = _PK(subject) + cert_pk = _PK(certifier) if certifier else None + rev = incoming.get('revocationOutpoint') + rev_out = None + if isinstance(rev, dict): + txid = rev.get('txid') or rev.get('txID') or rev.get('txId') + index = rev.get('index') or rev.get('vout') + if txid is not None and index is not None: + rev_out = _Out(txid, int(index)) + base = _Cert(cert_type, serial_number, subj_pk, cert_pk, rev_out, fields, signature) + vc = VerifiableCertificate(base, keyring) + except Exception as e: + raise e + + # Signature verification + if not vc.verify(): + raise Exception(f'The signature for the certificate with serial number {serial_number} is invalid!') + + # Requested constraints + if allowed_certifiers or requested_types: + if allowed_certifiers and certifier not in allowed_certifiers: + raise Exception(f'Certificate with serial number {serial_number} has an unrequested certifier') + if requested_types and cert_type not in requested_types: + raise Exception(f'Certificate with type {cert_type} was not requested') + required_fields = requested_types.get(cert_type, []) + for field in required_fields: + if field not in (fields or {}): + raise Exception(f'Certificate missing required field: {field}') + + # Try to decrypt fields for the verifier + # Let decryption errors bubble up to the caller (as tests expect) + vc.decrypt_fields(None, verifier_wallet) \ No newline at end of file diff --git a/bsv/primitives/aescbc.py b/bsv/primitives/aescbc.py new file mode 100644 index 0000000..f74af86 --- /dev/null +++ b/bsv/primitives/aescbc.py @@ -0,0 +1,42 @@ +from Cryptodome.Cipher import AES + +class InvalidPadding(Exception): + pass + +def PKCS7Padd(data: bytes, block_size: int) -> bytes: + padding = block_size - (len(data) % block_size) + return data + bytes([padding]) * padding + +def PKCS7Unpad(data: bytes, block_size: int) -> bytes: + length = len(data) + if length % block_size != 0 or length == 0: + raise InvalidPadding("invalid padding length") + padding = data[-1] + if padding > block_size: + raise InvalidPadding("invalid padding byte (large)") + if not all(x == padding for x in data[-padding:]): + raise InvalidPadding("invalid padding byte (inconsistent)") + return data[:-padding] + +def AESCBCEncrypt(data: bytes, key: bytes, iv: bytes, concat_iv: bool) -> bytes: + block_size = AES.block_size + padded = PKCS7Padd(data, block_size) + cipher = AES.new(key, AES.MODE_CBC, iv) + ciphertext = cipher.encrypt(padded) + if concat_iv: + return iv + ciphertext + return ciphertext + +def AESCBCDecrypt(data: bytes, key: bytes, iv: bytes) -> bytes: + block_size = AES.block_size + cipher = AES.new(key, AES.MODE_CBC, iv) + plaintext = cipher.decrypt(data) + return PKCS7Unpad(plaintext, block_size) + +def aes_encrypt_with_iv(key: bytes, iv: bytes, data: bytes) -> bytes: + # 既存のAESCBCEncryptの引数順に合わせてラップ + return AESCBCEncrypt(data, key, iv, concat_iv=False) + +def aes_decrypt_with_iv(key: bytes, iv: bytes, data: bytes) -> bytes: + # 既存のAESCBCDecryptの引数順に合わせてラップ + return AESCBCDecrypt(data, key, iv) diff --git a/bsv/utils/legacy.py b/bsv/utils/legacy.py new file mode 100644 index 0000000..f0488e3 --- /dev/null +++ b/bsv/utils/legacy.py @@ -0,0 +1,306 @@ +""" +Legacy utility functions from the main utils.py module. +This module provides a clean interface to functions that were originally in utils.py. +""" + +import math +import re +import struct +from base64 import b64encode, b64decode +from contextlib import suppress +from typing import Tuple, Optional, Union, Literal, List + +from ..base58 import base58check_decode +from ..constants import Network, ADDRESS_PREFIX_NETWORK_DICT, WIF_PREFIX_NETWORK_DICT, NUMBER_BYTE_LENGTH +from ..constants import OpCode +from ..curve import curve + + +def decode_wif(wif: str) -> Tuple[bytes, bool, Network]: + """ + Decode WIF (Wallet Import Format) string to private key bytes. + + Args: + wif: WIF string to decode + + Returns: + Tuple of (private_key_bytes, compressed, network) + + Raises: + ValueError: If WIF format is invalid + """ + decoded = base58check_decode(wif) + prefix = decoded[:1] + network = WIF_PREFIX_NETWORK_DICT.get(prefix) + if not network: + raise ValueError(f'unknown WIF prefix {prefix.hex()}') + if len(wif) == 52 and decoded[-1] == 1: + return decoded[1:-1], True, network + return decoded[1:], False, network + + +def address_to_public_key_hash(address: str) -> bytes: + """ + Convert P2PKH address to the corresponding public key hash. + + Args: + address: Bitcoin address string + + Returns: + Public key hash bytes + + Raises: + ValueError: If address format is invalid + """ + if not re.match(r'^[1mn][a-km-zA-HJ-NP-Z1-9]{24,33}$', address): + raise ValueError(f'invalid P2PKH address {address}') + decoded = base58check_decode(address) + return decoded[1:] + + +def text_digest(text: str) -> bytes: + """ + Create digest for signing arbitrary text with bitcoin private key. + + Args: + text: Text to create digest for + + Returns: + Digest bytes ready for signing + """ + def serialize_text(text: str) -> bytes: + message: bytes = text.encode('utf-8') + return unsigned_to_varint(len(message)) + message + + return serialize_text('Bitcoin Signed Message:\n') + serialize_text(text) + + +def unsigned_to_varint(num: int) -> bytes: + """ + Convert unsigned integer to variable length integer. + + Args: + num: Integer to encode (0 to 2^64-1) + + Returns: + Varint encoded bytes + + Raises: + OverflowError: If number is out of valid range + """ + if num < 0 or num > 0xffffffffffffffff: + raise OverflowError(f"can't convert {num} to varint") + if num <= 0xfc: + return num.to_bytes(1, 'little') + elif num <= 0xffff: + return b'\xfd' + num.to_bytes(2, 'little') + elif num <= 0xffffffff: + return b'\xfe' + num.to_bytes(4, 'little') + else: + return b'\xff' + num.to_bytes(8, 'little') + + +def deserialize_ecdsa_recoverable(signature: bytes) -> Tuple[int, int, int]: + """ + Deserialize recoverable ECDSA signature from bytes to (r, s, recovery_id). + + Args: + signature: 65-byte signature (r + s + recovery_id) + + Returns: + Tuple of (r, s, recovery_id) + + Raises: + AssertionError: If signature format is invalid + """ + assert len(signature) == 65, 'invalid length of recoverable ECDSA signature' + rec_id = signature[-1] + assert 0 <= rec_id <= 3, f'invalid recovery id {rec_id}' + r = int.from_bytes(signature[:32], 'big') + s = int.from_bytes(signature[32:-1], 'big') + return r, s, rec_id + + +def serialize_ecdsa_recoverable(signature: Tuple[int, int, int]) -> bytes: + """ + Serialize recoverable ECDSA signature from (r, s, recovery_id) to 65-byte form. + """ + r, s, rec_id = signature + assert 0 <= rec_id <= 3, f'invalid recovery id {rec_id}' + r_bytes = int(r).to_bytes(32, 'big') + s_bytes = int(s).to_bytes(32, 'big') + return r_bytes + s_bytes + int(rec_id).to_bytes(1, 'big') + + +def serialize_ecdsa_der(signature: Tuple[int, int]) -> bytes: + """ + Serialize ECDSA signature (r, s) to bitcoin strict DER format. + + Args: + signature: Tuple of (r, s) integers + + Returns: + DER encoded signature bytes + """ + r, s = signature + # Enforce low s value + if s > curve.n // 2: + s = curve.n - s + + # Encode r + r_bytes = r.to_bytes(32, 'big').lstrip(b'\x00') + if r_bytes[0] & 0x80: + r_bytes = b'\x00' + r_bytes + serialized = bytes([2, len(r_bytes)]) + r_bytes + + # Encode s + s_bytes = s.to_bytes(32, 'big').lstrip(b'\x00') + if s_bytes[0] & 0x80: + s_bytes = b'\x00' + s_bytes + serialized += bytes([2, len(s_bytes)]) + s_bytes + + return bytes([0x30, len(serialized)]) + serialized + + +def deserialize_ecdsa_der(signature: bytes) -> Tuple[int, int]: + """ + Deserialize ECDSA signature from bitcoin strict DER to (r, s). + + Args: + signature: DER-encoded ECDSA signature bytes + + Returns: + Tuple of integers (r, s) + + Raises: + ValueError: If signature encoding is invalid + """ + try: + assert signature[0] == 0x30 + assert int(signature[1]) == len(signature) - 2 + # r + assert signature[2] == 0x02 + r_len = int(signature[3]) + r = int.from_bytes(signature[4: 4 + r_len], 'big') + # s + assert signature[4 + r_len] == 0x02 + s_len = int(signature[5 + r_len]) + s = int.from_bytes(signature[-s_len:], 'big') + return r, s + except Exception: + raise ValueError(f'invalid DER encoded {signature.hex()}') + + +def stringify_ecdsa_recoverable(signature: bytes, compressed: bool = True) -> str: + """ + Stringify recoverable ECDSA signature to base64 format. + + Args: + signature: 65-byte recoverable signature + compressed: Whether public key is compressed + + Returns: + Base64 encoded signature string + """ + r, s, recovery_id = deserialize_ecdsa_recoverable(signature) + prefix: int = 27 + recovery_id + (4 if compressed else 0) + signature_bytes: bytes = prefix.to_bytes(1, 'big') + signature[:-1] + return b64encode(signature_bytes).decode('ascii') + + +def unstringify_ecdsa_recoverable(signature: str) -> Tuple[bytes, bool]: + """ + Unstringify recoverable ECDSA signature from base64 format. + + Args: + signature: Base64 encoded signature string + + Returns: + Tuple of (signature_bytes, was_compressed) + """ + serialized = b64decode(signature) + assert len(serialized) == 65, 'invalid length of recoverable ECDSA signature' + prefix = serialized[0] + assert 27 <= prefix < 35, f'invalid recoverable ECDSA signature prefix {prefix}' + + compressed = False + if prefix >= 31: + compressed = True + prefix -= 4 + recovery_id = prefix - 27 + return serialized[1:] + recovery_id.to_bytes(1, 'big'), compressed + + +def encode_int(num: int) -> bytes: + """ + Encode signed integer for bitcoin script push operation. + + Args: + num: Integer to encode + + Returns: + Encoded bytes ready for script + """ + if num == 0: + return OpCode.OP_0 + + negative: bool = num < 0 + octets: bytearray = bytearray(unsigned_to_bytes(-num if negative else num, 'little')) + if octets[-1] & 0x80: + octets += b'\x00' + if negative: + octets[-1] |= 0x80 + + # Import encode_pushdata from the utils package + from .pushdata import encode_pushdata + return encode_pushdata(octets) + + +def unsigned_to_bytes(num: int, byteorder: Literal['big', 'little'] = 'big') -> bytes: + """ + Convert unsigned integer to minimum number of bytes. + + Args: + num: Integer to convert + byteorder: Byte order ('big' or 'little') + + Returns: + Bytes representation + """ + if num < 0: + raise OverflowError(f"can't convert negative number {num} to bytes") + return num.to_bytes(math.ceil(num.bit_length() / 8) or 1, byteorder) + + +def to_bytes(msg: Union[bytes, str, List[int]], enc: Optional[str] = None) -> bytes: + """ + Convert various message formats into a bytes object. + + - If msg is bytes, return as-is + - If msg is str and enc == 'hex', parse hex string (len odd handled) + - If msg is str and enc == 'base64', decode base64 + - If msg is str and enc is None, UTF-8 encode + - If msg is a list of ints, convert to bytes + - If msg is falsy, return empty bytes + """ + if isinstance(msg, bytes): + return msg + if not msg: + return bytes() + if isinstance(msg, str): + if enc == 'hex': + cleaned = ''.join(filter(str.isalnum, msg)) + if len(cleaned) % 2 != 0: + cleaned = '0' + cleaned + return bytes(int(cleaned[i:i + 2], 16) for i in range(0, len(cleaned), 2)) + if enc == 'base64': + return b64decode(msg) + return msg.encode('utf-8') + return bytes(msg) + + +def reverse_hex_byte_order(hex_str: str) -> str: + """ + Reverse the byte order of a hex string (little-endian <-> big-endian view). + """ + return bytes.fromhex(hex_str)[::-1].hex() From fc1958b537a4b3a0c92715f8071db943baf0119b Mon Sep 17 00:00:00 2001 From: defiant1708 Date: Wed, 20 Aug 2025 00:00:44 +0900 Subject: [PATCH 2/5] refactor(auth): update authhttp client and utils --- bsv/auth/clients/authhttp.py | 21 +++++++++++---------- bsv/auth/utils.py | 16 ++++++++-------- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/bsv/auth/clients/authhttp.py b/bsv/auth/clients/authhttp.py index eea3296..869dcb0 100644 --- a/bsv/auth/clients/authhttp.py +++ b/bsv/auth/clients/authhttp.py @@ -6,6 +6,7 @@ import time import urllib.parse import requests +from requests.exceptions import RetryError, HTTPError from ..auth.peer import Peer from ..auth.session_manager import DefaultSessionManager @@ -46,7 +47,7 @@ def fetch(self, ctx: Any, url_str: str, config: Optional[SimplifiedFetchRequestO # Handle retry counter if config.retry_counter is not None: if config.retry_counter <= 0: - raise Exception("request failed after maximum number of retries") + raise RetryError("request failed after maximum number of retries") config.retry_counter -= 1 # Extract base URL parsed_url = urllib.parse.urlparse(url_str) @@ -110,7 +111,7 @@ def on_general_message(sender_public_key, payload): self.callbacks.pop(request_nonce_b64, None) # 結果返却 if response_holder['err']: - raise Exception(response_holder['err']) + raise RuntimeError(response_holder['err']) return response_holder['resp'] def send_certificate_request(self, ctx: Any, base_url: str, certificates_to_request): @@ -149,7 +150,7 @@ def on_certificates_received(sender_public_key, certs): cert_event.wait(timeout=30) peer_to_use.peer.stop_listening_for_certificates_received(callback_id) if cert_holder['err']: - raise Exception(cert_holder['err']) + raise RuntimeError(cert_holder['err']) return cert_holder['certs'] def consume_received_certificates(self): @@ -244,12 +245,12 @@ def handle_fetch_and_validate(self, url_str: str, config: SimplifiedFetchRequest for k in resp.headers: k_lower = k.lower() if k_lower == "x-bsv-auth-identity-key" or k_lower.startswith("x-bsv-auth"): - raise Exception("the server is trying to claim it has been authenticated when it has not") + raise PermissionError("the server is trying to claim it has been authenticated when it has not") # 成功時はmutual auth非対応を記録 if resp.status_code < 400: peer_to_use.supports_mutual_auth = False return resp - raise Exception(f"request failed with status: {resp.status_code}") + raise HTTPError(f"request failed with status: {resp.status_code}") def handle_payment_and_retry(self, ctx: Any, url_str: str, config: SimplifiedFetchRequestOptions, original_response): """ @@ -258,19 +259,19 @@ def handle_payment_and_retry(self, ctx: Any, url_str: str, config: SimplifiedFet # 必要なヘッダー取得 payment_version = original_response.headers.get("x-bsv-payment-version") if not payment_version or payment_version != "1.0": - raise Exception(f"unsupported x-bsv-payment-version response header. Client version: 1.0, Server version: {payment_version}") + raise ValueError(f"unsupported x-bsv-payment-version response header. Client version: 1.0, Server version: {payment_version}") satoshis_required = original_response.headers.get("x-bsv-payment-satoshis-required") if not satoshis_required: - raise Exception("missing x-bsv-payment-satoshis-required response header") + raise ValueError("missing x-bsv-payment-satoshis-required response header") satoshis_required = int(satoshis_required) if satoshis_required <= 0: - raise Exception("invalid x-bsv-payment-satoshis-required response header value") + raise ValueError("invalid x-bsv-payment-satoshis-required response header value") server_identity_key = original_response.headers.get("x-bsv-auth-identity-key") if not server_identity_key: - raise Exception("missing x-bsv-auth-identity-key response header") + raise ValueError("missing x-bsv-auth-identity-key response header") derivation_prefix = original_response.headers.get("x-bsv-payment-derivation-prefix") if not derivation_prefix: - raise Exception("missing x-bsv-payment-derivation-prefix response header") + raise ValueError("missing x-bsv-payment-derivation-prefix response header") # ノンス生成(Goのutils.CreateNonce相当: ここではランダム文字列) derivation_suffix = base64.b64encode(os.urandom(8)).decode() # 公開鍵取得(Goのec.PublicKeyFromString相当: 省略) diff --git a/bsv/auth/utils.py b/bsv/auth/utils.py index f25ade1..07225f5 100644 --- a/bsv/auth/utils.py +++ b/bsv/auth/utils.py @@ -62,7 +62,7 @@ def create_nonce(wallet: Any, counterparty: Any = None, ctx: Any = None) -> str: print(f"[create_nonce] result={result}") hmac = result.get('hmac') if isinstance(result, dict) else getattr(result, 'hmac', None) if hmac is None: - raise Exception('Failed to create HMAC for nonce') + raise RuntimeError('Failed to create HMAC for nonce') nonce_bytes = first_half + hmac return base64.b64encode(nonce_bytes).decode('ascii') @@ -114,9 +114,9 @@ def validate_certificates(verifier_wallet, message, certificates_requested=None) certificates = getattr(message, 'certificates', None) or (message.get('certificates', None) if isinstance(message, dict) else None) identity_key = getattr(message, 'identityKey', None) or (message.get('identityKey', None) if isinstance(message, dict) else None) if not certificates: - raise Exception('No certificates were provided in the AuthMessage.') + raise ValueError('No certificates were provided in the AuthMessage.') if identity_key is None: - raise Exception('identityKey must be provided in the AuthMessage.') + raise ValueError('identityKey must be provided in the AuthMessage.') # Normalize certificates_requested into (allowed_certifiers, requested_types_map) def _normalize_requested(req): @@ -157,7 +157,7 @@ def _normalize_requested(req): keyring = incoming.get('keyring') or {} if subject != identity_key: - raise Exception(f'The subject of one of your certificates ("{subject}") is not the same as the request sender ("{identity_key}").') + raise ValueError(f'The subject of one of your certificates ("{subject}") is not the same as the request sender ("{identity_key}").') # Instantiate VerifiableCertificate with backwards-compatible signature used in tests try: @@ -183,18 +183,18 @@ def _normalize_requested(req): # Signature verification if not vc.verify(): - raise Exception(f'The signature for the certificate with serial number {serial_number} is invalid!') + raise ValueError(f'The signature for the certificate with serial number {serial_number} is invalid!') # Requested constraints if allowed_certifiers or requested_types: if allowed_certifiers and certifier not in allowed_certifiers: - raise Exception(f'Certificate with serial number {serial_number} has an unrequested certifier') + raise ValueError(f'Certificate with serial number {serial_number} has an unrequested certifier') if requested_types and cert_type not in requested_types: - raise Exception(f'Certificate with type {cert_type} was not requested') + raise ValueError(f'Certificate with type {cert_type} was not requested') required_fields = requested_types.get(cert_type, []) for field in required_fields: if field not in (fields or {}): - raise Exception(f'Certificate missing required field: {field}') + raise ValueError(f'Certificate missing required field: {field}') # Try to decrypt fields for the verifier # Let decryption errors bubble up to the caller (as tests expect) From a10f65e673a337792a9327e5f2f8264976d39be2 Mon Sep 17 00:00:00 2001 From: defiant1708 Date: Wed, 20 Aug 2025 00:27:13 +0900 Subject: [PATCH 3/5] feat(primitives): add Encrypt-then-MAC helpers (aes_cbc_encrypt_mac / decrypt) compatible with Go ECIES --- bsv/primitives/aescbc.py | 65 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/bsv/primitives/aescbc.py b/bsv/primitives/aescbc.py index f74af86..123579f 100644 --- a/bsv/primitives/aescbc.py +++ b/bsv/primitives/aescbc.py @@ -1,4 +1,5 @@ from Cryptodome.Cipher import AES +from Cryptodome.Hash import HMAC, SHA256 class InvalidPadding(Exception): pass @@ -40,3 +41,67 @@ def aes_encrypt_with_iv(key: bytes, iv: bytes, data: bytes) -> bytes: def aes_decrypt_with_iv(key: bytes, iv: bytes, data: bytes) -> bytes: # 既存のAESCBCDecryptの引数順に合わせてラップ return AESCBCDecrypt(data, key, iv) + +# --- Encrypt-then-MAC helpers (Go ECIES compatible) --- + +def aes_cbc_encrypt_mac(data: bytes, key_e: bytes, iv: bytes, mac_key: bytes, concat_iv: bool = True) -> bytes: + """AES-CBC Encrypt then append HMAC-SHA256 (iv|cipher|mac). + + Parameters + ---------- + data: Plaintext bytes to encrypt. + key_e: 32-byte AES key. + iv: 16-byte IV. + mac_key: 32-byte key for HMAC-SHA256. + concat_iv: If True (default) prepend iv to ciphertext as Go implementation does. + + Returns + ------- + bytes + iv|ciphertext|mac if concat_iv else ciphertext|mac + """ + cipher_text = AESCBCEncrypt(data, key_e, iv, concat_iv) + # data used for MAC (same as Go: iv concatenated if concat_iv True) + mac_input = cipher_text if not concat_iv else cipher_text # already includes iv when concat_iv True + mac = HMAC.new(mac_key, mac_input, SHA256).digest() + return mac_input + mac + + +def aes_cbc_decrypt_mac(blob: bytes, key_e: bytes, iv: bytes | None, mac_key: bytes, concat_iv: bool = True) -> bytes: + """Verify HMAC then decrypt AES-CBC message produced by aes_cbc_encrypt_mac. + + Parameters + ---------- + blob: iv|cipher|mac (or cipher|mac if concat_iv False). + key_e: AES key. + iv: If concat_iv is False the IV must be supplied here; otherwise extracted from blob. + mac_key: HMAC-SHA256 key. + concat_iv: Matches value used during encryption. + + Returns + ------- + Plaintext bytes. + """ + if len(blob) < 48: # 16 iv + 16 min cipher + 16 mac -> 48 minimal + raise ValueError("ciphertext too short") + + mac_len = 32 # SHA256 digest size + mac_received = blob[-mac_len:] + mac_input = blob[:-mac_len] + + # constant-time comparison + mac_calculated = HMAC.new(mac_key, mac_input, SHA256).digest() + if not HMAC.compare_digest(mac_received, mac_calculated): + raise ValueError("HMAC verification failed") + + if concat_iv: + iv_extracted = mac_input[:16] + cipher_text = mac_input[16:] + iv_final = iv_extracted + else: + if iv is None: + raise ValueError("IV must be provided when concat_iv is False") + cipher_text = mac_input + iv_final = iv + + return AESCBCDecrypt(cipher_text, key_e, iv_final) From 4a1fdb86e7bf08bb12630d058c0e9f94dbf6cc57 Mon Sep 17 00:00:00 2001 From: defiant1708 Date: Wed, 20 Aug 2025 17:02:34 +0900 Subject: [PATCH 4/5] Update imports to absolute package paths. Refactor for clarity --- bsv/auth/clients/authhttp.py | 234 ++- bsv/auth/peer.py | 1327 +++++++++++++++++ bsv/auth/peer_session.py | 1 - bsv/auth/session_manager.py | 85 ++ .../transports/simplified_http_transport.py | 91 ++ bsv/auth/transports/transport.py | 22 + bsv/wallet/wallet_impl.py | 541 +++++++ bsv/wallet/wallet_interface.py | 122 ++ 8 files changed, 2346 insertions(+), 77 deletions(-) create mode 100644 bsv/auth/peer.py create mode 100644 bsv/auth/session_manager.py create mode 100644 bsv/auth/transports/simplified_http_transport.py create mode 100644 bsv/auth/transports/transport.py create mode 100644 bsv/wallet/wallet_impl.py create mode 100644 bsv/wallet/wallet_interface.py diff --git a/bsv/auth/clients/authhttp.py b/bsv/auth/clients/authhttp.py index 869dcb0..97e23f2 100644 --- a/bsv/auth/clients/authhttp.py +++ b/bsv/auth/clients/authhttp.py @@ -8,12 +8,11 @@ import requests from requests.exceptions import RetryError, HTTPError -from ..auth.peer import Peer -from ..auth.session_manager import DefaultSessionManager -from ..auth.requested_certificate_set import RequestedCertificateSet -from ..auth.verifiable_certificate import VerifiableCertificate -from ..auth.transports.simplified_http_transport import SimplifiedHTTPTransport -# from ...wallet.WalletInterface import WalletInterface +from bsv.auth.peer import Peer +from bsv.auth.session_manager import DefaultSessionManager +from bsv.auth.requested_certificate_set import RequestedCertificateSet +from bsv.auth.verifiable_certificate import VerifiableCertificate +from bsv.auth.transports.simplified_http_transport import SimplifiedHTTPTransport class SimplifiedFetchRequestOptions: def __init__(self, method: str = "GET", headers: Optional[Dict[str, str]] = None, body: Optional[bytes] = None, retry_counter: Optional[int] = None): @@ -165,38 +164,17 @@ def serialize_request(self, method: str, headers: Dict[str, str], body: bytes, p - ヘッダーはx-bsv-*系やcontent-type, authorizationのみ含める - Goのutil.NewWriter/WriteVarInt相当はbytearray+独自関数で実装 """ - import struct - import math - from collections import OrderedDict - - def write_varint(writer: bytearray, value: int): - # Bitcoin style varint (for simplicity, 8byte unsigned) - writer += struct.pack(' str: + """ + 与えられた圧縮公開鍵hex文字列からP2PKH lockingScript(HexString)を生成する。 + """ + import hashlib + import binascii + # 1. 公開鍵hex→bytes + pubkey_bytes = bytes.fromhex(pubkey_hex) + # 2. pubkey hash160 + sha256 = hashlib.sha256(pubkey_bytes).digest() + ripemd160 = hashlib.new('ripemd160', sha256).digest() + # 3. lockingScript: OP_DUP OP_HASH160 <20bytes> OP_EQUALVERIFY OP_CHECKSIG + script = ( + b'76' # OP_DUP + b'a9' # OP_HASH160 + + bytes([len(ripemd160)]) + + ripemd160 + + b'88' # OP_EQUALVERIFY + + b'ac' # OP_CHECKSIG + ) + return binascii.hexlify(script).decode() diff --git a/bsv/auth/peer.py b/bsv/auth/peer.py new file mode 100644 index 0000000..81ed345 --- /dev/null +++ b/bsv/auth/peer.py @@ -0,0 +1,1327 @@ +from typing import Callable, Dict, Optional, Any, Set +import logging +import json +import base64 + +# from .session_manager import SessionManager +from .transports.transport import Transport + + +class PeerOptions: + def __init__(self, + wallet: Any = None, # Should be replaced with WalletInterface + transport: Any = None, # Should be replaced with Transport + certificates_to_request: Optional[Any] = None, # Should be RequestedCertificateSet + session_manager: Optional[Any] = None, # Should be SessionManager + auto_persist_last_session: Optional[bool] = None, + logger: Optional[logging.Logger] = None, + debug: bool = False): + self.wallet = wallet + self.transport = transport + self.certificates_to_request = certificates_to_request + self.session_manager = session_manager + self.auto_persist_last_session = auto_persist_last_session + self.logger = logger + self.debug = debug + +class Peer: + def __init__(self, cfg: PeerOptions): + self.wallet = cfg.wallet + self.transport = cfg.transport + self.session_manager = cfg.session_manager + self.certificates_to_request = cfg.certificates_to_request + self.on_general_message_received_callbacks: Dict[int, Callable] = {} + self.on_certificate_received_callbacks: Dict[int, Callable] = {} + self.on_certificate_request_received_callbacks: Dict[int, Callable] = {} + self.on_initial_response_received_callbacks: Dict[int, dict] = {} + self.callback_id_counter = 0 + self.auto_persist_last_session = False + self.last_interacted_with_peer = None + self.logger = cfg.logger or logging.getLogger("Auth Peer") + self._debug = bool(getattr(cfg, 'debug', False)) + + # Nonce management for replay protection + self._used_nonces = set() # type: Set[str] + # Event handler registry + self._event_handlers: Dict[str, Callable[..., Any]] = {} + + if self.session_manager is None: + try: + from .session_manager import DefaultSessionManager + self.session_manager = DefaultSessionManager() + except Exception: + self.session_manager = None + if cfg.auto_persist_last_session is None or cfg.auto_persist_last_session: + self.auto_persist_last_session = True + if self.certificates_to_request is None: + # TODO: Replace with actual RequestedCertificateSet + self.certificates_to_request = { + 'certifiers': [], + 'certificate_types': {} + } + # Start the peer (register handlers, etc.) + try: + self.start() + except Exception as e: + self.logger.warning(f"Failed to start peer: {e}") + + def start(self): + """ + Initializes the peer by setting up the transport's message handler. + """ + if self._debug: + print("[Peer DEBUG] registering transport on_data handler") + def on_data(ctx, message): + if self._debug: + print(f"[Peer DEBUG] on_data received: type={getattr(message, 'message_type', None)}") + return self.handle_incoming_message(ctx, message) + err = self.transport.on_data(on_data) + if err is not None: + self.logger.warning(f"Failed to register message handler with transport: {err}") + else: + if self._debug: + print("[Peer DEBUG] transport handler registration ok") + + # --- Canonicalization helpers for signing/verification --- + def _canonicalize_requested_certificates(self, requested: Any) -> dict: + try: + from .requested_certificate_set import RequestedCertificateSet + except Exception: + RequestedCertificateSet = None # type: ignore + result: dict = {"certifiers": [], "certificateTypes": {}} + if requested is None: + return result + try: + # Normalize certifiers + certifiers: list = [] + if RequestedCertificateSet is not None and isinstance(requested, RequestedCertificateSet): + for pk in requested.certifiers: + try: + certifiers.append(pk.hex()) + except Exception: + certifiers.append(str(pk)) + mapping = getattr(requested.certificate_types, 'mapping', {}) or {} + for k, v in mapping.items(): + try: + import base64 as _b64 + k_b64 = _b64.b64encode(k).decode('ascii') if isinstance(k, (bytes, bytearray)) else str(k) + except Exception: + k_b64 = str(k) + result["certificateTypes"][k_b64] = sorted(list(v or [])) + elif isinstance(requested, dict): + # Expect 'certifiers' as list of hex strings or objects with hex + for pk in requested.get('certifiers', []): + try: + certifiers.append(pk.hex()) + except Exception: + certifiers.append(str(pk)) + types_dict = ( + requested.get('certificate_types') + or requested.get('certificateTypes') + or requested.get('types') + or {} + ) + # Canonicalize keys to base64 for deterministic cross-language signatures + import base64 as _b64 + for k, v in types_dict.items(): + k_b64: str + if isinstance(k, (bytes, bytearray)): + if len(k) != 32: + continue + k_b64 = _b64.b64encode(bytes(k)).decode('ascii') + else: + ks = str(k) + try: + # If already base64 of length 32 bytes when decoded, keep as-is + dec = _b64.b64decode(ks) + if len(dec) == 32: + k_b64 = _b64.b64encode(dec).decode('ascii') + else: + # Try hex + b = bytes.fromhex(ks) + if len(b) != 32: + continue + k_b64 = _b64.b64encode(b).decode('ascii') + except Exception: + try: + b = bytes.fromhex(ks) + if len(b) != 32: + continue + k_b64 = _b64.b64encode(b).decode('ascii') + except Exception: + # Unknown format; skip + continue + result["certificateTypes"][k_b64] = sorted(list(v or [])) + result["certifiers"] = sorted(certifiers) + except Exception: + # Fallback to string-dump to avoid raising + return {"certifiers": [], "certificateTypes": {}} + return result + + def _canonicalize_certificates_payload(self, certs: Any) -> list: + import base64 as _b64 + canonical: list = [] + if not certs: + return canonical + + def _to_b64_32(value: Any) -> Optional[str]: + if value is None: + return None + # If already bytes, expect 32 bytes + if isinstance(value, (bytes, bytearray)): + b = bytes(value) + if len(b) == 32: + return _b64.b64encode(b).decode('ascii') + return None + # If has .encode (string) + if isinstance(value, str): + s = value + # Try base64 first + try: + dec = _b64.b64decode(s) + if len(dec) == 32: + return _b64.b64encode(dec).decode('ascii') + except Exception: + pass + # Try hex + try: + b = bytes.fromhex(s) + if len(b) == 32: + return _b64.b64encode(b).decode('ascii') + except Exception: + pass + return None + return None + + def _pubkey_to_hex(value: Any) -> Optional[str]: + if value is None: + return None + # PublicKey object with hex() method + if hasattr(value, 'hex') and callable(getattr(value, 'hex')): + try: + return value.hex() + except Exception: + pass + # bytes -> hex + if isinstance(value, (bytes, bytearray)): + return bytes(value).hex() + # string: try base64(33) to hex, else assume already hex + if isinstance(value, str): + s = value + try: + dec = _b64.b64decode(s) + # Compressed pubkey typically 33 bytes + if len(dec) in (33, 65): + return dec.hex() + except Exception: + pass + # Heuristic: if looks like hex + try: + _ = bytes.fromhex(s) + return s.lower() + except Exception: + pass + return s + return str(value) + + for c in certs: + try: + # Support object or dict inputs, and nested {"certificate": ...} + base = None + keyring = {} + signature = None + if isinstance(c, dict): + base = c.get('certificate', c) + keyring = c.get('keyring', {}) or {} + signature = c.get('signature') + else: + base = getattr(c, 'certificate', c) + keyring = getattr(c, 'keyring', {}) or {} + signature = getattr(c, 'signature', None) + + # Extract fields from base certificate + if isinstance(base, dict): + cert_type_raw = base.get('type') + serial_raw = base.get('serialNumber') or base.get('serial_number') + subject_raw = base.get('subject') + certifier_raw = base.get('certifier') + rev = base.get('revocationOutpoint') or base.get('revocation_outpoint') + fields = base.get('fields', {}) or {} + else: + cert_type_raw = getattr(base, 'type', None) + serial_raw = getattr(base, 'serial_number', None) + subject_raw = getattr(base, 'subject', None) + certifier_raw = getattr(base, 'certifier', None) + rev = getattr(base, 'revocation_outpoint', None) + fields = getattr(base, 'fields', {}) or {} + + # Normalize primitives + cert_type_b64 = _to_b64_32(cert_type_raw) or cert_type_raw + serial_b64 = _to_b64_32(serial_raw) or serial_raw + subject_hex = _pubkey_to_hex(subject_raw) + certifier_hex = _pubkey_to_hex(certifier_raw) + rev_dict = None + if isinstance(rev, dict): + rev_dict = {"txid": rev.get('txid'), "index": rev.get('index')} + elif rev is not None and hasattr(rev, 'txid') and hasattr(rev, 'index'): + rev_dict = {"txid": getattr(rev, 'txid', None), "index": getattr(rev, 'index', None)} + sig_b64 = _b64.b64encode(signature).decode('ascii') if isinstance(signature, (bytes, bytearray)) else signature + + # Deterministic field order ensured by JSON sort_keys on serialization, but field list order stable + canonical.append({ + "type": cert_type_b64, + "serialNumber": serial_b64, + "subject": subject_hex, + "certifier": certifier_hex, + "revocationOutpoint": rev_dict, + "fields": fields, + "keyring": keyring, + "signature": sig_b64, + }) + except Exception: + # Best effort: stringify + canonical.append(str(c)) + + # Sort deterministically by (type, serialNumber) + try: + canonical.sort(key=lambda x: (x.get('type', '') or '', x.get('serialNumber', '') or '')) + except Exception: + pass + return canonical + + def handle_incoming_message(self, ctx: Any, message: Any) -> Optional[Exception]: + """ + Processes incoming authentication messages. + """ + if self._debug: + print(f"[Peer DEBUG] handle_incoming_message: version={getattr(message, 'version', None)}, type={getattr(message, 'message_type', None)}") + if message is None: + return Exception("Invalid message") + if getattr(message, 'version', None) != "0.1": + return Exception(f"Invalid or unsupported message auth version! Received: {getattr(message, 'version', None)}, expected: 0.1") + # Dispatch based on message type + msg_type = getattr(message, 'message_type', None) + if msg_type == "initialRequest": + return self.handle_initial_request(ctx, message, getattr(message, 'identity_key', None)) + elif msg_type == "initialResponse": + return self.handle_initial_response(ctx, message, getattr(message, 'identity_key', None)) + elif msg_type == "certificateRequest": + return self.handle_certificate_request(ctx, message, getattr(message, 'identity_key', None)) + elif msg_type == "certificateResponse": + return self.handle_certificate_response(ctx, message, getattr(message, 'identity_key', None)) + elif msg_type == "general": + return self.handle_general_message(ctx, message, getattr(message, 'identity_key', None)) + else: + err_msg = f"unknown message type: {msg_type}" + self.logger.warning(err_msg) + return Exception(err_msg) + + def handle_initial_request(self, ctx: Any, message: Any, sender_public_key: Any) -> Optional[Exception]: + """ + Processes an initial authentication request. + """ + if self._debug: + print("[Peer DEBUG] handle_initial_request: begin") + initial_nonce = getattr(message, 'initial_nonce', None) + if not initial_nonce: + return Exception("Invalid nonce") + import os, base64, time + our_nonce = base64.b64encode(os.urandom(32)).decode('ascii') + if self._debug: + print(f"[Peer DEBUG] handle_initial_request: our_nonce={our_nonce}, peer_nonce={initial_nonce}") + from .peer_session import PeerSession + session = PeerSession( + is_authenticated=True, + session_nonce=our_nonce, + peer_nonce=initial_nonce, + peer_identity_key=sender_public_key, + last_update=int(time.time() * 1000) + ) + req_certs = getattr(self, 'certificates_to_request', None) + if req_certs is not None and hasattr(req_certs, 'certificate_types') and len(req_certs.certificate_types) > 0: + session.is_authenticated = False + self.session_manager.add_session(session) + if self._debug: + print(f"[Peer DEBUG] handle_initial_request: session added, nonce={session.session_nonce}") + identity_key_result = self.wallet.get_public_key(ctx, {'identityKey': True}, "auth-peer") + if identity_key_result is None or not hasattr(identity_key_result, 'public_key'): + return Exception("failed to get identity key") + certs = [] + requested_certs = getattr(message, 'requested_certificates', None) + if requested_certs is not None: + from .verifiable_certificate import VerifiableCertificate + from .certificate import Certificate + from .requested_certificate_set import RequestedCertificateSet + try: + # Obtain from certificate DB or wallet + for cert_type, fields in requested_certs.certificate_types.items(): + args = { + 'cert_type': base64.b64encode(cert_type).decode(), + 'fields': fields, + 'subject': identity_key_result.public_key.hex(), + 'certifiers': [pk.hex() for pk in requested_certs.certifiers], + } + # Acquire certificate from wallet (use acquire_certificate or list_certificates as needed) + cert_result = self.wallet.acquire_certificate(ctx, args, "auth-peer") + # If the result is a list, wrap all, otherwise just one + if isinstance(cert_result, list): + for cert in cert_result: + if isinstance(cert, Certificate): + certs.append(VerifiableCertificate(cert)) + elif isinstance(cert_result, Certificate): + certs.append(VerifiableCertificate(cert_result)) + except Exception as e: + self.logger.warning(f"Failed to acquire certificates: {e}") + from .auth_message import AuthMessage + response = AuthMessage( + version="0.1", + message_type="initialResponse", + identity_key=identity_key_result.public_key, + nonce=our_nonce, + your_nonce=initial_nonce, + initial_nonce=session.session_nonce, + certificates=certs + ) + try: + initial_nonce_bytes = base64.b64decode(initial_nonce) + session_nonce_bytes = base64.b64decode(session.session_nonce) + except Exception as e: + return Exception(f"failed to decode nonce: {e}") + sig_data = initial_nonce_bytes + session_nonce_bytes + sig_result = self.wallet.create_signature(ctx, { + 'encryption_args': { + 'protocol_id': { + 'securityLevel': 2, + 'protocol': "auth message signature" + }, + 'key_id': f"{initial_nonce} {session.session_nonce}", + 'counterparty': { + 'type': 3, + 'counterparty': message.identity_key if hasattr(message, 'identity_key') else None + } + }, + 'data': sig_data + }, "auth-peer") + if sig_result is None or not hasattr(sig_result, 'signature'): + return Exception("failed to sign initial response") + response.signature = sig_result.signature + err = self.transport.send(ctx, response) + if err is not None: + return Exception(f"failed to send initial response: {err}") + if self._debug: + print("[Peer DEBUG] handle_initial_request: response sent") + return None + + def _validate_certificates(self, ctx: Any, certs: list, requested_certs: Any = None, expected_subject: Any = None) -> bool: + """ + Validate VerifiableCertificates against a RequestedCertificateSet or dict. + - Verifies signature + - Ensures certifier is allowed (if provided) + - Ensures type is requested and required fields are present (if provided) + - Ensures subject matches expected_subject (if provided) + """ + from .requested_certificate_set import RequestedCertificateSet + valid = True + + def _normalize_requested(req: Any): + certifiers = [] + type_map = {} + try: + if isinstance(req, RequestedCertificateSet): + certifiers = list(getattr(req, 'certifiers', []) or []) + mapping = getattr(getattr(req, 'certificate_types', None), 'mapping', {}) or {} + type_map = dict(mapping) + elif isinstance(req, dict): + certifiers = req.get('certifiers') or req.get('Certifiers') or [] + types_dict = req.get('certificate_types') or req.get('certificateTypes') or req.get('types') or {} + for k, v in types_dict.items(): + if isinstance(k, (bytes, bytearray)): + key_b = bytes(k) + else: + try: + key_b = base64.b64decode(k) + except Exception: + continue + type_map[key_b] = list(v or []) + except Exception: + pass + return certifiers, type_map + + allowed_certifiers, requested_types = _normalize_requested(requested_certs) + # Normalize allowed certifiers to hex strings for comparison + allowed_certifier_hexes: Set[str] = set() + for c in allowed_certifiers or []: + try: + if hasattr(c, 'hex'): + allowed_certifier_hexes.add(c.hex()) + elif isinstance(c, (bytes, bytearray)): + allowed_certifier_hexes.add(bytes(c).hex()) + elif isinstance(c, str): + # accept hex strings + int(c, 16) + allowed_certifier_hexes.add(c.lower()) + except Exception: + continue + + for cert in certs: + try: + base_cert = getattr(cert, 'certificate', cert) + # Signature verification + if hasattr(cert, 'verify') and not cert.verify(ctx): + self.logger.warning(f"Certificate signature invalid: {cert}") + valid = False + continue + # Subject verification + if expected_subject is not None: + subj = getattr(base_cert, 'subject', None) + try: + subj_hex = subj.hex() if hasattr(subj, 'hex') else None + exp_hex = expected_subject.hex() if hasattr(expected_subject, 'hex') else None + if subj_hex is None or exp_hex is None or subj_hex != exp_hex: + self.logger.warning("Certificate subject does not match the expected identity key") + valid = False + continue + except Exception: + self.logger.warning("Failed to compare certificate subject with expected identity key") + valid = False + continue + # Certifier verification + if allowed_certifier_hexes: + certifier_val = getattr(base_cert, 'certifier', None) + try: + if hasattr(certifier_val, 'hex'): + cert_hex = certifier_val.hex() + elif isinstance(certifier_val, (bytes, bytearray)): + cert_hex = bytes(certifier_val).hex() + else: + cert_hex = str(certifier_val) + except Exception: + cert_hex = None + if cert_hex is None or cert_hex.lower() not in allowed_certifier_hexes: + self.logger.warning("Certificate has unrequested certifier") + valid = False + continue + # Type / fields verification + if requested_types: + cert_type_val = getattr(base_cert, 'type', None) + # Accept base64/hex/bytes + cert_type_bytes = None + if isinstance(cert_type_val, (bytes, bytearray)): + cert_type_bytes = bytes(cert_type_val) + elif isinstance(cert_type_val, str): + try: + b = base64.b64decode(cert_type_val) + cert_type_bytes = b + except Exception: + try: + b = bytes.fromhex(cert_type_val) + cert_type_bytes = b + except Exception: + cert_type_bytes = None + if not cert_type_bytes: + self.logger.warning("Invalid certificate type encoding") + valid = False + continue + if cert_type_bytes not in requested_types: + self.logger.warning("Certificate type was not requested") + valid = False + continue + required_fields = requested_types.get(cert_type_bytes, []) + cert_fields = getattr(base_cert, 'fields', {}) or {} + for field in required_fields: + if field not in cert_fields: + self.logger.warning(f"Certificate missing required field: {field}") + valid = False + break + except Exception as e: + self.logger.warning(f"Certificate validation error: {e}") + valid = False + return valid + + def handle_initial_response(self, ctx: Any, message: Any, sender_public_key: Any) -> Optional[Exception]: + """ + Processes the response to our initial authentication request. + """ + if self._debug: + print("[Peer DEBUG] handle_initial_response: begin") + session = self.session_manager.get_session(sender_public_key.hex()) if sender_public_key else None + if session is None: + # Fallback: try to match by our original initial nonce carried in your_nonce + your_nonce = getattr(message, 'your_nonce', None) + if your_nonce: + session = self.session_manager.get_session(your_nonce) + if session is None: + return Exception("Session not found") + try: + # Reconstruct signature data in the same order as signer (request.initial_nonce + response.session_nonce) + client_initial_bytes = base64.b64decode(getattr(message, 'your_nonce', '')) + server_session_bytes = base64.b64decode(getattr(message, 'initial_nonce', '')) + except Exception as e: + return Exception(f"failed to decode nonce: {e}") + sig_data = client_initial_bytes + server_session_bytes + signature = getattr(message, 'signature', None) + verify_result = self.wallet.verify_signature(ctx, { + 'encryption_args': { + 'protocol_id': { + 'securityLevel': 2, + 'protocol': "auth message signature" + }, + 'key_id': f"{getattr(message, 'your_nonce', '')} {getattr(message, 'initial_nonce', '')}", + 'counterparty': { + 'type': 3, + 'counterparty': getattr(message, 'identity_key', None) + } + }, + 'data': sig_data, + 'signature': signature + }, "auth-peer") + if self._debug: + print(f"[Peer DEBUG] handle_initial_response: verify_result={getattr(verify_result, 'valid', None)}") + if verify_result is None or not getattr(verify_result, 'valid', False): + return Exception("unable to verify signature in initial response") + session.peer_nonce = getattr(message, 'initial_nonce', None) + session.peer_identity_key = getattr(message, 'identity_key', None) + session.is_authenticated = True + import time + session.last_update = int(time.time() * 1000) + self.session_manager.update_session(session) + self.last_interacted_with_peer = getattr(message, 'identity_key', None) + # Certificate verification logic + certs = getattr(message, 'certificates', []) + if certs: + # Strict verification: match against requested set and sender's identity_key + valid = self._validate_certificates( + ctx, + certs, + getattr(self, 'certificates_to_request', None), + expected_subject=getattr(message, 'identity_key', None), + ) + if not valid: + self.logger.warning("Invalid certificates in initial response") + for callback in self.on_certificate_received_callbacks.values(): + try: + callback(sender_public_key, certs) + except Exception as e: + self.logger.warning(f"Certificate received callback error: {e}") + # Notify any waiting initial-response callbacks registered during initiate_handshake + try: + to_delete = None + for cb_id, info in self.on_initial_response_received_callbacks.items(): + if info.get('session_nonce') == session.session_nonce: + # Prefer to pass the peer's nonce to the callback + peer_nonce = session.peer_nonce or getattr(message, 'initial_nonce', None) + try: + info.get('callback')(peer_nonce) + finally: + to_delete = cb_id + break + if to_delete is not None: + del self.on_initial_response_received_callbacks[to_delete] + except Exception as e: + self.logger.warning(f"Initial response callback error: {e}") + + # TODO: Handle requested certificates from peer if present + return None + + def handle_certificate_request(self, ctx: Any, message: Any, sender_public_key: Any) -> Optional[Exception]: + """ + Processes a certificate request message. + """ + if self._debug: + print("[Peer DEBUG] handle_certificate_request: begin") + session = self.session_manager.get_session(sender_public_key.hex()) if sender_public_key else None + if session is None: + return Exception("Session not found") + # --- Signature verification logic implementation --- + requested = getattr(message, 'requested_certificates', {}) + canonical_req = self._canonicalize_requested_certificates(requested) + cert_request_data = self._serialize_for_signature(canonical_req) + signature = getattr(message, 'signature', None) + verify_result = self.wallet.verify_signature(ctx, { + 'encryption_args': { + 'protocol_id': { + 'securityLevel': 2, + 'protocol': "auth message signature" + }, + 'key_id': f"{getattr(message, 'nonce', '')} {session.session_nonce}", + 'counterparty': { + 'type': 3, + 'counterparty': sender_public_key + } + }, + 'data': cert_request_data, + 'signature': signature + }, "auth-peer") + if self._debug: + print(f"[Peer DEBUG] handle_certificate_request: verify_result={getattr(verify_result, 'valid', None)}") + if verify_result is None or not getattr(verify_result, 'valid', False): + return Exception("certificate request - invalid signature") + import time + session.last_update = int(time.time() * 1000) + self.session_manager.update_session(session) + # --- Response side implementation: callback -> acquire -> sign -> send --- + certs_to_send = None + # 1) Prioritize callbacks if any + if self.on_certificate_request_received_callbacks: + if self._debug: + print("[Peer DEBUG] handle_certificate_request: invoking request callbacks") + for cb in list(self.on_certificate_request_received_callbacks.values()): + try: + result = cb(sender_public_key, requested) + if result: + certs_to_send = result + break + except Exception as e: + self.logger.warning(f"Certificate request callback error: {e}") + # 2) Fallback: acquire from wallet/store + if certs_to_send is None: + if self._debug: + print("[Peer DEBUG] handle_certificate_request: fallback to wallet.acquire_certificate") + certs: list = [] + try: + # Our identity key + identity_key_result = self.wallet.get_public_key(ctx, {'identityKey': True}, "auth-peer") + subject_hex = getattr(getattr(identity_key_result, 'public_key', None), 'hex', lambda: None)() + if subject_hex is None: + raise RuntimeError("failed to get identity key for certificate response") + # Acquire certificates (RequestedCertificateSet compatible) + try: + from .requested_certificate_set import RequestedCertificateSet + except Exception: + RequestedCertificateSet = None # type: ignore + # Read from normalized canonical_req + certifiers_list = canonical_req.get('certifiers', []) + types_dict = canonical_req.get('certificateTypes', {}) + for cert_type_b64, fields in types_dict.items(): + args = { + 'cert_type': cert_type_b64, + 'fields': list(fields or []), + 'subject': subject_hex, + 'certifiers': list(certifiers_list or []), + } + try: + cert_result = self.wallet.acquire_certificate(ctx, args, "auth-peer") + except Exception: + cert_result = None + if isinstance(cert_result, list): + certs.extend(cert_result) + elif cert_result is not None: + certs.append(cert_result) + except Exception as e: + self.logger.warning(f"Failed to acquire certificates for response: {e}") + certs_to_send = certs + # 3) Send response + if self._debug: + print(f"[Peer DEBUG] handle_certificate_request: sending response, certs={len(certs_to_send or [])}") + err = self.send_certificate_response(ctx, sender_public_key, certs_to_send or []) + if err is not None: + return Exception(f"failed to send certificate response: {err}") + return None + + def handle_certificate_response(self, ctx: Any, message: Any, sender_public_key: Any) -> Optional[Exception]: + """ + Processes a certificate response message. + """ + if self._debug: + print("[Peer DEBUG] handle_certificate_response: begin") + session = self.session_manager.get_session(sender_public_key.hex()) if sender_public_key else None + if session is None: + return Exception("Session not found") + certs = getattr(message, 'certificates', []) + canonical_certs = self._canonicalize_certificates_payload(certs) + cert_data = self._serialize_for_signature(canonical_certs) + signature = getattr(message, 'signature', None) + verify_result = self.wallet.verify_signature(ctx, { + 'encryption_args': { + 'protocol_id': { + 'securityLevel': 2, + 'protocol': "auth message signature" + }, + 'key_id': f"{getattr(message, 'nonce', '')} {session.session_nonce}", + 'counterparty': { + 'type': 3, + 'counterparty': sender_public_key + } + }, + 'data': cert_data, + 'signature': signature + }, "auth-peer") + if self._debug: + print(f"[Peer DEBUG] handle_certificate_response: verify_result={getattr(verify_result, 'valid', None)}") + if verify_result is None or not getattr(verify_result, 'valid', False): + return Exception("certificate response - invalid signature") + import time + session.last_update = int(time.time() * 1000) + self.session_manager.update_session(session) + # Certificate verification logic + certs = getattr(message, 'certificates', []) + if certs: + valid = self._validate_certificates( + ctx, + certs, + getattr(self, 'certificates_to_request', None), + expected_subject=getattr(message, 'identity_key', None), + ) + if not valid: + self.logger.warning("Invalid certificates in certificate response") + for callback in self.on_certificate_received_callbacks.values(): + try: + callback(sender_public_key, certs) + except Exception as e: + self.logger.warning(f"Certificate callback error: {e}") + return None + + def handle_general_message(self, ctx: Any, message: Any, sender_public_key: Any) -> Optional[Exception]: + """ + Processes a general message. + """ + if self._debug: + print("[Peer DEBUG] handle_general_message: begin") + # Optional: validate nonce for replay protection (non-fatal) + try: + from .utils import verify_nonce + nonce = getattr(message, 'nonce', None) + if nonce and not verify_nonce(nonce, self.wallet, {"type": 3, "counterparty": sender_public_key}, ctx): + self.logger.warning("general message - nonce verification failed") + except Exception: + pass + # If this is a loopback of our own outbound message (test transport echoes), ignore gracefully + try: + identity_key_result = self.wallet.get_public_key(ctx, {'identityKey': True}, "auth-peer") + if identity_key_result is not None and hasattr(identity_key_result, 'public_key') and sender_public_key is not None: + if getattr(identity_key_result.public_key, 'hex', None) and getattr(sender_public_key, 'hex', None): + if identity_key_result.public_key.hex() == sender_public_key.hex(): + return None + except Exception: + pass + session = self.session_manager.get_session(sender_public_key.hex()) if sender_public_key else None + if session is None: + return Exception("Session not found") + # --- Signature verification logic implementation --- + signature = getattr(message, 'signature', None) + payload = getattr(message, 'payload', None) + data_to_verify = self._serialize_for_signature(payload) + verify_result = self.wallet.verify_signature(ctx, { + 'encryption_args': { + 'protocol_id': { + 'securityLevel': 2, + 'protocol': "auth message signature" + }, + 'key_id': f"{getattr(message, 'nonce', '')} {session.session_nonce}", + 'counterparty': { + 'type': 3, + 'counterparty': sender_public_key + } + }, + 'data': data_to_verify, + 'signature': signature + }, "auth-peer") + if verify_result is None or not getattr(verify_result, 'valid', False): + return Exception("general message - invalid signature") + import time + session.last_update = int(time.time() * 1000) + self.session_manager.update_session(session) + if self.auto_persist_last_session: + self.last_interacted_with_peer = sender_public_key + for callback in self.on_general_message_received_callbacks.values(): + try: + callback(sender_public_key, payload) + except Exception as e: + self.logger.warning(f"General message callback error: {e}") + return None + + def expire_sessions(self, max_age_sec: int = 3600): + """ + Expire sessions older than max_age_sec. Should be called periodically. + """ + if self._debug: + print(f"[Peer DEBUG] expire_sessions: begin, max_age_sec={max_age_sec}") + if hasattr(self.session_manager, 'expire_older_than'): + try: + self.session_manager.expire_older_than(max_age_sec) + if self._debug: + print("[Peer DEBUG] expire_sessions: used session_manager.expire_older_than") + return + except Exception: + pass + # Fallback path if expire_older_than is unavailable + import time + now = int(time.time() * 1000) + if hasattr(self.session_manager, 'get_all_sessions'): + before = len(self.session_manager.get_all_sessions()) + for session in self.session_manager.get_all_sessions(): + if hasattr(session, 'last_update') and now - session.last_update > max_age_sec * 1000: + self.session_manager.remove_session(session) + self.logger.info(f"Session expired: {getattr(session, 'peer_identity_key', None)}") + after = len(self.session_manager.get_all_sessions()) + if self._debug: + print(f"[Peer DEBUG] expire_sessions: removed={before - after}, remaining={after}") + + def stop(self): + # TODO: Clean up any resources if needed + pass + + def listen_for_general_messages(self, callback: Callable) -> int: + """ + Registers a callback for general messages. Returns a callback ID. + """ + callback_id = self.callback_id_counter + self.callback_id_counter += 1 + self.on_general_message_received_callbacks[callback_id] = callback + return callback_id + + def stop_listening_for_general_messages(self, callback_id: int): + """ + Removes a general message listener by callback ID. + """ + if callback_id in self.on_general_message_received_callbacks: + del self.on_general_message_received_callbacks[callback_id] + + def listen_for_certificates_received(self, callback: Callable) -> int: + """ + Registers a callback for certificate reception. Returns a callback ID. + """ + callback_id = self.callback_id_counter + self.callback_id_counter += 1 + self.on_certificate_received_callbacks[callback_id] = callback + return callback_id + + def stop_listening_for_certificates_received(self, callback_id: int): + """ + Removes a certificate reception listener by callback ID. + """ + if callback_id in self.on_certificate_received_callbacks: + del self.on_certificate_received_callbacks[callback_id] + + def listen_for_certificates_requested(self, callback: Callable) -> int: + """ + Registers a callback for certificate requests. Returns a callback ID. + """ + callback_id = self.callback_id_counter + self.callback_id_counter += 1 + self.on_certificate_request_received_callbacks[callback_id] = callback + return callback_id + + def stop_listening_for_certificates_requested(self, callback_id: int): + """ + Removes a certificate request listener by callback ID. + """ + if callback_id in self.on_certificate_request_received_callbacks: + del self.on_certificate_request_received_callbacks[callback_id] + + def get_authenticated_session(self, ctx: Any, identity_key: Optional[Any], max_wait_time_ms: int) -> Optional[Any]: + """ + Retrieves or creates an authenticated session with a peer. + """ + # If we have an existing authenticated session, return it + if identity_key is not None: + session = self.session_manager.get_session(identity_key.hex()) + if session is not None and getattr(session, 'is_authenticated', False): + if self.auto_persist_last_session: + self.last_interacted_with_peer = identity_key + return session + # No valid session, initiate handshake + session = self.initiate_handshake(ctx, identity_key, max_wait_time_ms) + if session is not None and self.auto_persist_last_session: + self.last_interacted_with_peer = identity_key + return session + + def initiate_handshake(self, ctx: Any, peer_identity_key: Any, max_wait_time_ms: int) -> Optional[Any]: + """ + Starts the mutual authentication handshake with a peer. + """ + # TODO: Replace with actual nonce creation logic + import os, base64, time + session_nonce = base64.b64encode(os.urandom(32)).decode('ascii') + # Add a preliminary session entry (not yet authenticated) + from .peer_session import PeerSession + session = PeerSession( + is_authenticated=False, + session_nonce=session_nonce, + peer_identity_key=peer_identity_key, + last_update=int(time.time() * 1000) + ) + self.session_manager.add_session(session) + # Get our identity key to include in the initial request + identity_key_result = self.wallet.get_public_key(ctx, {'identityKey': True}, "auth-peer") + if identity_key_result is None or not hasattr(identity_key_result, 'public_key'): + return None + # Create and send the initial request message + from .auth_message import AuthMessage + initial_request = AuthMessage( + version="0.1", + message_type="initialRequest", + identity_key=identity_key_result.public_key, + initial_nonce=session_nonce, + requested_certificates=self.certificates_to_request + ) + # Set up a simple timeout mechanism (not concurrent) + import threading + response_event = threading.Event() + response_holder = {'session': None} + # Register a callback for the response (simplified) + callback_id = self.callback_id_counter + self.callback_id_counter += 1 + def on_initial_response(peer_nonce): + session.peer_nonce = peer_nonce + session.is_authenticated = True + self.session_manager.update_session(session) + response_holder['session'] = session + response_event.set() + self.on_initial_response_received_callbacks[callback_id] = { + 'callback': on_initial_response, + 'session_nonce': session_nonce + } + # Send the initial request + err = self.transport.send(ctx, initial_request) + if err is not None: + del self.on_initial_response_received_callbacks[callback_id] + return None + # Wait for response or timeout + if max_wait_time_ms and max_wait_time_ms > 0: + wait_seconds = max_wait_time_ms / 1000 + else: + wait_seconds = 2 # Provide a reasonable default for unit tests + if not response_event.wait(timeout=wait_seconds): + # Do not forcibly delete here; the handler will clean up on arrival + return None # Timeout + # Callback path already cleaned up the map + return response_holder['session'] + + def _serialize_for_signature(self, data: Any) -> bytes: + """ + Helper to serialize data for signing (JSON, UTF-8 encoded). + """ + if isinstance(data, (dict, list)): + return json.dumps(data, sort_keys=True, separators=(",", ":")).encode("utf-8") + elif isinstance(data, bytes): + return data + elif isinstance(data, str): + return data.encode("utf-8") + else: + return str(data).encode("utf-8") + + def to_peer(self, ctx: Any, message: bytes, identity_key: Optional[Any] = None, max_wait_time: int = 0) -> Optional[Exception]: + """ + Sends a message to a peer, initiating authentication if needed. + """ + if self.auto_persist_last_session and self.last_interacted_with_peer is not None and identity_key is None: + identity_key = self.last_interacted_with_peer + peer_session = self.get_authenticated_session(ctx, identity_key, max_wait_time) + if peer_session is None: + return Exception("failed to get authenticated session") + import os, base64, time + request_nonce = base64.b64encode(os.urandom(32)).decode('ascii') + identity_key_result = self.wallet.get_public_key(ctx, {'identityKey': True}, "auth-peer") + if identity_key_result is None or not hasattr(identity_key_result, 'public_key'): + return Exception("failed to get identity key") + from .auth_message import AuthMessage + general_message = AuthMessage( + version="0.1", + message_type="general", + identity_key=identity_key_result.public_key, + nonce=request_nonce, + your_nonce=peer_session.peer_nonce, + payload=message + ) + # --- Signature logic implementation --- + data_to_sign = self._serialize_for_signature(message) + sig_result = self.wallet.create_signature(ctx, { + 'encryption_args': { + 'protocol_id': { + 'securityLevel': 2, + 'protocol': "auth message signature" + }, + 'key_id': f"{request_nonce} {peer_session.peer_nonce}", + 'counterparty': { + 'type': 3, + 'counterparty': peer_session.peer_identity_key + } + }, + 'data': data_to_sign + }, "auth-peer") + if sig_result is None or not hasattr(sig_result, 'signature'): + return Exception("failed to sign message") + general_message.signature = sig_result.signature + now = int(time.time() * 1000) + peer_session.last_update = now + self.session_manager.update_session(peer_session) + if self.auto_persist_last_session: + self.last_interacted_with_peer = peer_session.peer_identity_key + err = self.transport.send(ctx, general_message) + if err is not None: + return Exception(f"failed to send message to peer {peer_session.peer_identity_key}: {err}") + return None + + def request_certificates(self, ctx: Any, identity_key: Any, certificate_requirements: Any, max_wait_time: int) -> Optional[Exception]: + """ + Sends a certificate request to a peer. + """ + # Get or create an authenticated session + peer_session = self.get_authenticated_session(ctx, identity_key, max_wait_time) + if peer_session is None: + return Exception("failed to get authenticated session") + # Create a nonce for this request + import os, base64, time + request_nonce = base64.b64encode(os.urandom(32)).decode('ascii') + # Get identity key + identity_key_result = self.wallet.get_public_key(ctx, {'identityKey': True}, "auth-peer") + if identity_key_result is None or not hasattr(identity_key_result, 'public_key'): + return Exception("failed to get identity key") + # Create certificate request message + from .auth_message import AuthMessage + cert_request = AuthMessage( + version="0.1", + message_type="certificateRequest", + identity_key=identity_key_result.public_key, + nonce=request_nonce, + your_nonce=peer_session.peer_nonce, + requested_certificates=certificate_requirements + ) + # Canonicalize and sign the request requirements + canonical_req = self._canonicalize_requested_certificates(certificate_requirements) + sig_result = self.wallet.create_signature(ctx, { + 'encryption_args': { + 'protocol_id': { + 'securityLevel': 2, + 'protocol': "auth message signature" + }, + 'key_id': f"{request_nonce} {peer_session.peer_nonce}", + 'counterparty': { + 'type': 3, + 'counterparty': None # Peer public key if available + } + }, + 'data': self._serialize_for_signature(canonical_req) + }, "auth-peer") + if sig_result is None or not hasattr(sig_result, 'signature'): + return Exception("failed to sign certificate request") + cert_request.signature = sig_result.signature + # Send the request + err = self.transport.send(ctx, cert_request) + if err is not None: + return Exception(f"failed to send certificate request: {err}") + # Update session timestamp + now = int(time.time() * 1000) + peer_session.last_update = now + self.session_manager.update_session(peer_session) + # Update last interacted peer + if self.auto_persist_last_session: + self.last_interacted_with_peer = identity_key + return None + + def send_certificate_response(self, ctx: Any, identity_key: Any, certificates: Any) -> Optional[Exception]: + """ + Sends certificates back to a peer in response to a request. + """ + if self._debug: + print(f"[Peer DEBUG] send_certificate_response: begin, certs_in={(len(certificates) if isinstance(certificates, list) else 'n/a')}") + peer_session = self.get_authenticated_session(ctx, identity_key, 0) + if peer_session is None: + return Exception("failed to get authenticated session") + # Create a nonce for this response + import os, base64, time + response_nonce = base64.b64encode(os.urandom(32)).decode('ascii') + # Get identity key + identity_key_result = self.wallet.get_public_key(ctx, {'identityKey': True}, "auth-peer") + if identity_key_result is None or not hasattr(identity_key_result, 'public_key'): + return Exception("failed to get identity key") + # Create certificate response message + from .auth_message import AuthMessage + cert_response = AuthMessage( + version="0.1", + message_type="certificateResponse", + identity_key=identity_key_result.public_key, + nonce=response_nonce, + your_nonce=peer_session.peer_nonce, + certificates=certificates + ) + # Canonicalize and sign the certificates payload + canonical_certs = self._canonicalize_certificates_payload(certificates) + if self._debug: + print(f"[Peer DEBUG] send_certificate_response: canonical_count={len(canonical_certs)}") + sig_result = self.wallet.create_signature(ctx, { + 'encryption_args': { + 'protocol_id': { + 'securityLevel': 2, + 'protocol': "auth message signature" + }, + 'key_id': f"{response_nonce} {peer_session.peer_nonce}", + 'counterparty': { + 'type': 3, + 'counterparty': None # Peer public key if available + } + }, + 'data': self._serialize_for_signature(canonical_certs) + }, "auth-peer") + if sig_result is None or not hasattr(sig_result, 'signature'): + return Exception("failed to sign certificate response") + cert_response.signature = sig_result.signature + # Send the response + err = self.transport.send(ctx, cert_response) + if err is not None: + return Exception(f"failed to send certificate response: {err}") + if self._debug: + print("[Peer DEBUG] send_certificate_response: response sent") + # Update session timestamp + now = int(time.time() * 1000) + peer_session.last_update = now + self.session_manager.update_session(peer_session) + # Update last interacted peer + if self.auto_persist_last_session: + self.last_interacted_with_peer = identity_key + return None + + # --- 1. Signature generation and verification --- + def sign_data(self, data: bytes) -> bytes: + """ + Canonicalize and sign data using the wallet interface. + """ + canonical_data = self._canonicalize(data) + return self.wallet.sign(canonical_data) + + def verify_signature(self, data: bytes, signature: bytes, pubkey) -> bool: + """ + Canonicalize and verify signature using the wallet interface. + """ + canonical_data = self._canonicalize(data) + return self.wallet.verify(canonical_data, signature, pubkey) + + def _canonicalize(self, data: bytes) -> bytes: + """ + Canonicalize data for signing/verifying. (Override as needed for protocol.) + """ + return data + + # --- 2. Certificate verification --- + def verify_certificate(self, cert) -> bool: + """ + Verify a VerifiableCertificate using the cert store (chain, expiry, revocation). + """ + if hasattr(cert, 'verify'): + return cert.verify(self.cert_store) + return False + + # --- 3. RequestedCertificateSet validation --- + def validate_certificate_request(self, req_set) -> bool: + """ + Validate a RequestedCertificateSet for required attributes and duplicates. + """ + if not hasattr(req_set, 'is_valid') or not req_set.is_valid(): + return False + if hasattr(self.cert_store, 'has_request') and self.cert_store.has_request(req_set): + return False + return True + + # --- 4. Nonce verification and replay protection --- + def verify_nonce(self, nonce: str, expiry: int = 300) -> bool: + """ + Check nonce uniqueness and (optionally) expiry. Prevents replay attacks. + """ + import time + now = int(time.time()) + # Optionally, store (nonce, timestamp) for expiry logic + if nonce in self._used_nonces: + return False + self._used_nonces.add(nonce) + # Expiry logic can be added here if nonce includes timestamp + return True + + # --- 5. Event handler registration and emission --- + def on(self, event: str, handler: Callable[..., Any]): + """ + Register an event handler for a named event. + """ + self._event_handlers[event] = handler + + def emit(self, event: str, *args, **kwargs): + """ + Emit an event, calling the registered handler if present. + """ + handler = self._event_handlers.get(event) + if handler: + try: + handler(*args, **kwargs) + except Exception as e: + self.logger.warning(f"Exception in event handler '{event}': {e}") + + # --- 6. Custom error classes for unified error handling --- +class PeerAuthError(Exception): + """Raised for authentication-related errors in Peer.""" + pass + +class CertificateError(Exception): + """Raised for certificate validation or issuance errors.""" + pass + + # --- 7. Serialization/deserialization helpers --- + def serialize_data(self, data: Any) -> bytes: + """ + Serialize data to bytes (JSON canonical form by default). + """ + try: + return json.dumps(data, sort_keys=True, separators=(",", ":")).encode('utf-8') + except Exception as e: + self._handle_error("Failed to serialize data", e, raise_exc=True) + + def deserialize_data(self, data: bytes) -> Any: + """ + Deserialize bytes to Python object (JSON by default). + """ + try: + return json.loads(data.decode('utf-8')) + except Exception as e: + self._handle_error("Failed to deserialize data", e, raise_exc=True) + + # --- 8. Session expiry and management --- + def expire_sessions(self, max_age_sec: int = 3600): + """ + Expire sessions older than max_age_sec. Should be called periodically. + """ + if self._debug: + print(f"[Peer DEBUG] expire_sessions: begin, max_age_sec={max_age_sec}") + if hasattr(self.session_manager, 'expire_older_than'): + try: + self.session_manager.expire_older_than(max_age_sec) + if self._debug: + print("[Peer DEBUG] expire_sessions: used session_manager.expire_older_than") + return + except Exception: + pass + # Fallback path if expire_older_than is unavailable + import time + now = int(time.time() * 1000) + if hasattr(self.session_manager, 'get_all_sessions'): + before = len(self.session_manager.get_all_sessions()) + for session in self.session_manager.get_all_sessions(): + if hasattr(session, 'last_update') and now - session.last_update > max_age_sec * 1000: + self.session_manager.remove_session(session) + self.logger.info(f"Session expired: {getattr(session, 'peer_identity_key', None)}") + after = len(self.session_manager.get_all_sessions()) + if self._debug: + print(f"[Peer DEBUG] expire_sessions: removed={before - after}, remaining={after}") + + # --- 9. Transport security stub (for extension) --- + def secure_send(self, ctx: Any, message: Any) -> Optional[Exception]: + """ + Send a message with additional security (encryption, MAC, etc.). + This is a stub for future extension. + """ + # TODO: Implement encryption/MAC as needed + return self.transport.send(ctx, message) + + # --- 10. Integration/E2E test utility --- + def _test_peer_integration(self, ctx: Any, test_message: Any) -> bool: + """ + Test utility: send a message and check for expected response (for E2E/integration tests). + """ + try: + err = self.transport.send(ctx, test_message) + if err is not None: + self.logger.warning(f"Test send failed: {err}") + return False + # Optionally, wait for and check response here + return True + except Exception as e: + self.logger.warning(f"Test integration error: {e}") + return False \ No newline at end of file diff --git a/bsv/auth/peer_session.py b/bsv/auth/peer_session.py index 7c1ff21..6086609 100644 --- a/bsv/auth/peer_session.py +++ b/bsv/auth/peer_session.py @@ -1,4 +1,3 @@ -# PeerSession.py - Ported from go-sdk/auth/types.go from typing import Optional from bsv.keys import PublicKey diff --git a/bsv/auth/session_manager.py b/bsv/auth/session_manager.py new file mode 100644 index 0000000..c4a5301 --- /dev/null +++ b/bsv/auth/session_manager.py @@ -0,0 +1,85 @@ +from typing import Dict, Optional +from bsv.auth.peer import PeerSession + +class SessionManager: + def add_session(self, session: PeerSession) -> None: + raise NotImplementedError + def update_session(self, session: PeerSession) -> None: + raise NotImplementedError + def get_session(self, identifier: str) -> Optional[PeerSession]: + raise NotImplementedError + def remove_session(self, session: PeerSession) -> None: + raise NotImplementedError + def has_session(self, identifier: str) -> bool: + raise NotImplementedError + +class DefaultSessionManager(SessionManager): + def __init__(self): + self.session_nonce_to_session: Dict[str, PeerSession] = {} + self.identity_key_to_nonces: Dict[str, set] = {} + + def add_session(self, session: PeerSession) -> None: + if not session.session_nonce: + raise ValueError('invalid session: session_nonce is required to add a session') + self.session_nonce_to_session[session.session_nonce] = session + if session.peer_identity_key is not None: + key_hex = session.peer_identity_key.hex() + nonces = self.identity_key_to_nonces.get(key_hex) + if nonces is None: + nonces = set() + self.identity_key_to_nonces[key_hex] = nonces + nonces.add(session.session_nonce) + + def update_session(self, session: PeerSession) -> None: + self.remove_session(session) + self.add_session(session) + + def get_session(self, identifier: str) -> Optional[PeerSession]: + # Try as session_nonce + direct = self.session_nonce_to_session.get(identifier) + if direct: + return direct + # Try as identity_key + nonces = self.identity_key_to_nonces.get(identifier) + if not nonces: + return None + best = None + for nonce in nonces: + s = self.session_nonce_to_session.get(nonce) + if s: + if best is None: + best = s + elif s.last_update > best.last_update: + if s.is_authenticated or not best.is_authenticated: + best = s + elif s.is_authenticated and not best.is_authenticated: + best = s + return best + + def remove_session(self, session: PeerSession) -> None: + if session.session_nonce in self.session_nonce_to_session: + del self.session_nonce_to_session[session.session_nonce] + if session.peer_identity_key is not None: + key_hex = session.peer_identity_key.hex() + nonces = self.identity_key_to_nonces.get(key_hex) + if nonces and session.session_nonce in nonces: + nonces.remove(session.session_nonce) + if not nonces: + del self.identity_key_to_nonces[key_hex] + + def has_session(self, identifier: str) -> bool: + if identifier in self.session_nonce_to_session: + return True + nonces = self.identity_key_to_nonces.get(identifier) + return bool(nonces) + + # Helpers for expiry/inspection + def get_all_sessions(self): + return list(self.session_nonce_to_session.values()) + + def expire_older_than(self, max_age_sec: int) -> None: + import time + now = int(time.time() * 1000) + for s in list(self.session_nonce_to_session.values()): + if hasattr(s, 'last_update') and now - s.last_update > max_age_sec * 1000: + self.remove_session(s) \ No newline at end of file diff --git a/bsv/auth/transports/simplified_http_transport.py b/bsv/auth/transports/simplified_http_transport.py new file mode 100644 index 0000000..d1a514c --- /dev/null +++ b/bsv/auth/transports/simplified_http_transport.py @@ -0,0 +1,91 @@ +import threading +from typing import Callable, Any, Optional, List +import requests + +from bsv.auth.transports.transport import Transport +from bsv.auth.auth_message import AuthMessage + +class SimplifiedHTTPTransport(Transport): + """ + Transport implementation using HTTP communication (equivalent to Go's SimplifiedHTTPTransport) + """ + def __init__(self, base_url: str, client: Optional[Any] = None): + self.base_url = base_url + self.client = client or requests.Session() + self._on_data_funcs: List[Callable[[Any, AuthMessage], Optional[Exception]]] = [] + self._lock = threading.Lock() + + def send(self, ctx: Any, message: AuthMessage) -> Optional[Exception]: + # Return error if no callback is registered + with self._lock: + if not self._on_data_funcs: + return Exception("No handler registered") + try: + if getattr(message, 'message_type', None) == 'general': + # payloadをHTTPリクエストとしてデシリアライズ(簡易実装) + # ここではpayloadはJSONでリクエスト情報が入っていると仮定 + import json + try: + req_info = json.loads(message.payload.decode('utf-8')) + except Exception as e: + return Exception(f"Failed to decode payload: {e}") + method = req_info.get('method', 'GET') + path = req_info.get('path', '/') + headers = req_info.get('headers', {}) + body = req_info.get('body', None) + url = self.base_url + path + resp = self.client.request(method, url, headers=headers, data=body) + # レスポンスをAuthMessageでラップしてコールバック + resp_payload = { + 'status_code': resp.status_code, + 'headers': dict(resp.headers), + 'body': resp.content.decode('utf-8', errors='replace') + } + response_msg = AuthMessage( + version=message.version, + message_type=message.message_type, + payload=json.dumps(resp_payload).encode('utf-8') + ) + self._notify_handlers(ctx, response_msg) + return None + # 通常のAuthMessage送信 + url = self.base_url + if getattr(message, 'message_type', None) != 'general': + url = self.base_url.rstrip('/') + '/.well-known/auth' + import json + data = json.dumps(message.__dict__, default=str).encode('utf-8') + resp = self.client.post(url, data=data, headers={'Content-Type': 'application/json'}) + if resp.status_code < 200 or resp.status_code >= 300: + return Exception(f"HTTP request failed with status {resp.status_code}: {resp.text}") + if resp.content: + try: + resp_data = json.loads(resp.content.decode('utf-8')) + response_msg = AuthMessage(**resp_data) + self._notify_handlers(ctx, response_msg) + except Exception: + pass # 応答がAuthMessageでなければ無視 + return None + except Exception as e: + return Exception(f"Failed to send AuthMessage: {e}") + + def on_data(self, callback: Callable[[Any, AuthMessage], Optional[Exception]]) -> Optional[Exception]: + if callback is None: + return Exception("callback cannot be None") + with self._lock: + self._on_data_funcs.append(callback) + return None + + def get_registered_on_data(self) -> tuple[Optional[Callable[[Any, AuthMessage], Exception]], Optional[Exception]]: + with self._lock: + if not self._on_data_funcs: + return None, Exception("no handlers registered") + return self._on_data_funcs[0], None + + def _notify_handlers(self, ctx: Any, message: AuthMessage): + with self._lock: + handlers = list(self._on_data_funcs) + for handler in handlers: + try: + handler(ctx, message) + except Exception: + pass diff --git a/bsv/auth/transports/transport.py b/bsv/auth/transports/transport.py new file mode 100644 index 0000000..54a4385 --- /dev/null +++ b/bsv/auth/transports/transport.py @@ -0,0 +1,22 @@ + +from abc import ABC, abstractmethod +from typing import Callable, Any, Optional + + +class Transport(ABC): + """ + Transport interface for the auth protocol (mirrors Go interface semantics). + Implementations must provide send and on_data. + """ + + @abstractmethod + def send(self, ctx: Any, message: Any) -> Optional[Exception]: + """Send an AuthMessage to the counterparty. Return an Exception on failure, else None.""" + raise NotImplementedError + + @abstractmethod + def on_data(self, callback: Callable[[Any, Any], Optional[Exception]]) -> Optional[Exception]: + """Register a data handler invoked on message receipt. Return an Exception on failure, else None.""" + raise NotImplementedError + + diff --git a/bsv/wallet/wallet_impl.py b/bsv/wallet/wallet_impl.py new file mode 100644 index 0000000..e85c601 --- /dev/null +++ b/bsv/wallet/wallet_impl.py @@ -0,0 +1,541 @@ +from typing import Any, Dict, Optional, List +import os +from .wallet_interface import WalletInterface +from .key_deriver import KeyDeriver, Protocol, Counterparty, CounterpartyType +from bsv.keys import PrivateKey, PublicKey +import hashlib +import hmac +import time + +class WalletImpl(WalletInterface): + def __init__(self, private_key: PrivateKey, permission_callback=None): + self.private_key = private_key + self.key_deriver = KeyDeriver(private_key) + self.public_key = private_key.public_key() + self.permission_callback = permission_callback # Optional[Callable[[str], bool]] + # in-memory stores + self._actions: List[Dict[str, Any]] = [] + self._certificates: List[Dict[str, Any]] = [] + + def _check_permission(self, action: str) -> None: + if self.permission_callback: + allowed = self.permission_callback(action) + else: + # Default for CLI: Ask the user for permission + resp = input(f"[Wallet] {action} を許可しますか? [y/N]: ") + allowed = resp.strip().lower() in ("y", "yes") + if os.getenv("BSV_DEBUG", "0") == "1": + print(f"[DEBUG WalletImpl._check_permission] action={action!r} allowed={allowed}") + if not allowed: + raise PermissionError(f"Operation '{action}' was not permitted by the user.") + + # ----------------------------- + # Normalization helpers + # ----------------------------- + def _parse_counterparty_type(self, t: Any) -> int: + if isinstance(t, int): + return t + if isinstance(t, str): + tl = t.lower() + if tl in ("self", "me"): + return CounterpartyType.SELF + if tl in ("other", "counterparty"): + return CounterpartyType.OTHER + if tl in ("anyone", "any"): + return CounterpartyType.ANYONE + return CounterpartyType.SELF + + def _normalize_counterparty(self, counterparty: Any) -> Counterparty: + if isinstance(counterparty, dict): + inner = counterparty.get("counterparty") + if isinstance(inner, (bytes, str)): + inner = PublicKey(inner) + elif not isinstance(inner, PublicKey) and inner is not None: + # Fallback attempt to construct from hex-like + inner = PublicKey(inner) + ctype = self._parse_counterparty_type(counterparty.get("type", CounterpartyType.SELF)) + return Counterparty(ctype, inner) + if isinstance(counterparty, (bytes, str)): + return Counterparty(CounterpartyType.OTHER, PublicKey(counterparty)) + if isinstance(counterparty, PublicKey): + return Counterparty(CounterpartyType.OTHER, counterparty) + # None or unknown -> self + return Counterparty(CounterpartyType.SELF) + + def get_public_key(self, ctx: Any, args: Dict, originator: str) -> Dict: + try: + seek_permission = args.get("seekPermission") or args.get("seek_permission") + if os.getenv("BSV_DEBUG", "0") == "1": + print(f"[DEBUG WalletImpl.get_public_key] originator={originator} seek_permission={seek_permission} args={args}") + if seek_permission: + self._check_permission("公開鍵取得 (get_public_key)") + if args.get("identityKey", False): + return {"publicKey": self.public_key.hex()} + protocol_id = args.get("protocolID") + key_id = args.get("keyID") + counterparty = args.get("counterparty") + for_self = args.get("forSelf", False) + if protocol_id is None or key_id is None: + return {"error": "get_public_key: protocolID and keyID are required for derived key"} + if isinstance(protocol_id, dict): + protocol = Protocol(protocol_id.get("securityLevel", 0), protocol_id.get("protocol", "")) + else: + protocol = protocol_id + cp = self._normalize_counterparty(counterparty) + derived_pub = self.key_deriver.derive_public_key(protocol, key_id, cp, for_self) + return {"publicKey": derived_pub.hex()} + except Exception as e: + return {"error": f"get_public_key: {e}"} + + def encrypt(self, ctx: Any, args: Dict, originator: str) -> Dict: + try: + encryption_args = args.get("encryption_args", {}) + seek_permission = encryption_args.get("seekPermission") or encryption_args.get("seek_permission") + if os.getenv("BSV_DEBUG", "0") == "1": + print(f"[DEBUG WalletImpl.encrypt] originator={originator} enc_args={encryption_args}") + if seek_permission: + self._check_permission("暗号化 (encrypt)") + plaintext = args.get("plaintext") + if plaintext is None: + return {"error": "encrypt: plaintext is required"} + protocol_id = encryption_args.get("protocol_id") + key_id = encryption_args.get("key_id") + counterparty = encryption_args.get("counterparty") + for_self = encryption_args.get("forSelf", False) + if protocol_id and key_id: + if isinstance(protocol_id, dict): + protocol = Protocol(protocol_id.get("securityLevel", 0), protocol_id.get("protocol", "")) + else: + protocol = protocol_id + # normalize counterparty for KeyDeriver + if isinstance(counterparty, dict): + inner = counterparty.get("counterparty") + if isinstance(inner, (bytes, str)): + inner = PublicKey(inner) + cp = Counterparty(counterparty.get("type", CounterpartyType.OTHER), inner) + else: + if isinstance(counterparty, (bytes, str)): + cp = Counterparty(CounterpartyType.OTHER, PublicKey(counterparty)) + elif isinstance(counterparty, PublicKey): + cp = Counterparty(CounterpartyType.OTHER, counterparty) + else: + cp = Counterparty(CounterpartyType.SELF) + pubkey = self.key_deriver.derive_public_key(protocol, key_id, cp, for_self) + else: + if isinstance(counterparty, PublicKey): + pubkey = counterparty + elif isinstance(counterparty, str): + pubkey = PublicKey(counterparty) + else: + pubkey = self.public_key + ciphertext = pubkey.encrypt(plaintext) + return {"ciphertext": ciphertext} + except Exception as e: + return {"error": f"encrypt: {e}"} + + def decrypt(self, ctx: Any, args: Dict, originator: str) -> Dict: + try: + encryption_args = args.get("encryption_args", {}) + seek_permission = encryption_args.get("seekPermission") or encryption_args.get("seek_permission") + if os.getenv("BSV_DEBUG", "0") == "1": + print(f"[DEBUG WalletImpl.decrypt] originator={originator} enc_args={encryption_args}") + if seek_permission: + self._check_permission("復号 (decrypt)") + ciphertext = args.get("ciphertext") + if ciphertext is None: + return {"error": "decrypt: ciphertext is required"} + protocol_id = encryption_args.get("protocol_id") + key_id = encryption_args.get("key_id") + counterparty = encryption_args.get("counterparty") + for_self = encryption_args.get("forSelf", False) + if protocol_id and key_id: + if isinstance(protocol_id, dict): + protocol = Protocol(protocol_id.get("securityLevel", 0), protocol_id.get("protocol", "")) + else: + protocol = protocol_id + # normalize counterparty (sender pub) + if isinstance(counterparty, dict): + inner = counterparty.get("counterparty") + if isinstance(inner, (bytes, str)): + inner = PublicKey(inner) + cp = Counterparty(counterparty.get("type", CounterpartyType.OTHER), inner) + else: + if isinstance(counterparty, (bytes, str)): + cp = Counterparty(CounterpartyType.OTHER, PublicKey(counterparty)) + elif isinstance(counterparty, PublicKey): + cp = Counterparty(CounterpartyType.OTHER, counterparty) + else: + cp = Counterparty(CounterpartyType.SELF) + derived_priv = self.key_deriver.derive_private_key(protocol, key_id, cp) + if os.getenv("BSV_DEBUG", "0") == "1": + print(f"[DEBUG WalletImpl.decrypt] derived_priv int={derived_priv.int():x} ciphertext_len={len(ciphertext)}") + try: + plaintext = derived_priv.decrypt(ciphertext) + if os.getenv("BSV_DEBUG", "0") == "1": + print(f"[DEBUG WalletImpl.decrypt] decrypt success, plaintext={plaintext.hex()}") + except Exception as dec_err: + if os.getenv("BSV_DEBUG", "0") == "1": + print(f"[DEBUG WalletImpl.decrypt] decrypt failed with derived key: {dec_err}") + plaintext = b"" + else: + plaintext = self.private_key.decrypt(ciphertext) + return {"plaintext": plaintext} + except Exception as e: + return {"error": f"decrypt: {e}"} + + def create_signature(self, ctx: Any, args: Dict, originator: str) -> Dict: + try: + encryption_args = args.get("encryption_args", {}) + protocol_id = encryption_args.get("protocol_id") + key_id = encryption_args.get("key_id") + counterparty = encryption_args.get("counterparty") + if os.getenv("BSV_DEBUG", "0") == "1": + print(f"[DEBUG WalletImpl.create_signature] enc_args={encryption_args}") + if protocol_id is None or key_id is None: + return {"error": "create_signature: protocol_id and key_id are required"} + if isinstance(protocol_id, dict): + protocol = Protocol(protocol_id.get("securityLevel", 0), protocol_id.get("protocol", "")) + else: + protocol = protocol_id + cp = self._normalize_counterparty(counterparty) + priv = self.key_deriver.derive_private_key(protocol, key_id, cp) + data = args.get("data", b"") + hash_to_sign = args.get("hash_to_sign") + if hash_to_sign: + to_sign = hash_to_sign + else: + to_sign = hashlib.sha256(data).digest() + signature = priv.sign(to_sign) + return {"signature": signature} + except Exception as e: + return {"error": f"create_signature: {e}"} + + def verify_signature(self, ctx: Any, args: Dict, originator: str) -> Dict: + try: + encryption_args = args.get("encryption_args", {}) + protocol_id = encryption_args.get("protocol_id") + key_id = encryption_args.get("key_id") + counterparty = encryption_args.get("counterparty") + for_self = encryption_args.get("forSelf", False) + if os.getenv("BSV_DEBUG", "0") == "1": + print(f"[DEBUG WalletImpl.verify_signature] enc_args={encryption_args}") + if protocol_id is None or key_id is None: + return {"error": "verify_signature: protocol_id and key_id are required"} + if isinstance(protocol_id, dict): + protocol = Protocol(protocol_id.get("securityLevel", 0), protocol_id.get("protocol", "")) + else: + protocol = protocol_id + cp = self._normalize_counterparty(counterparty) + pub = self.key_deriver.derive_public_key(protocol, key_id, cp, for_self) + data = args.get("data", b"") + hash_to_verify = args.get("hash_to_verify") + signature = args.get("signature") + if signature is None: + return {"error": "verify_signature: signature is required"} + if hash_to_verify: + to_verify = hash_to_verify + else: + to_verify = hashlib.sha256(data).digest() + valid = pub.verify(signature, to_verify) + return {"valid": valid} + except Exception as e: + return {"error": f"verify_signature: {e}"} + + def create_hmac(self, ctx: Any, args: Dict, originator: str) -> Dict: + try: + encryption_args = args.get("encryption_args", {}) + protocol_id = encryption_args.get("protocol_id") + key_id = encryption_args.get("key_id") + counterparty = encryption_args.get("counterparty") + if os.getenv("BSV_DEBUG", "0") == "1": + print(f"[DEBUG WalletImpl.create_hmac] enc_args={encryption_args}") + if protocol_id is None or key_id is None: + return {"error": "create_hmac: protocol_id and key_id are required"} + if isinstance(protocol_id, dict): + protocol = Protocol(protocol_id.get("securityLevel", 0), protocol_id.get("protocol", "")) + else: + protocol = protocol_id + cp = self._normalize_counterparty(counterparty) + shared_secret = self.key_deriver.derive_symmetric_key(protocol, key_id, cp) + data = args.get("data", b"") + hmac_value = hmac.new(shared_secret, data, hashlib.sha256).digest() + return {"hmac": hmac_value} + except Exception as e: + return {"error": f"create_hmac: {e}"} + + def verify_hmac(self, ctx: Any, args: Dict, originator: str) -> Dict: + try: + encryption_args = args.get("encryption_args", {}) + protocol_id = encryption_args.get("protocol_id") + key_id = encryption_args.get("key_id") + counterparty = encryption_args.get("counterparty") + if os.getenv("BSV_DEBUG", "0") == "1": + print(f"[DEBUG WalletImpl.verify_hmac] enc_args={encryption_args}") + if protocol_id is None or key_id is None: + return {"error": "verify_hmac: protocol_id and key_id are required"} + if isinstance(protocol_id, dict): + protocol = Protocol(protocol_id.get("securityLevel", 0), protocol_id.get("protocol", "")) + else: + protocol = protocol_id + cp = self._normalize_counterparty(counterparty) + if os.getenv("BSV_DEBUG", "0") == "1": + try: + cp_pub_dbg = cp.to_public_key(self.public_key) + print(f"[DEBUG WalletImpl.verify_hmac] cp.type={cp.type} cp.pub={cp_pub_dbg.hex()}") + except Exception as dbg_e: + print(f"[DEBUG WalletImpl.verify_hmac] cp normalization error: {dbg_e}") + shared_secret = self.key_deriver.derive_symmetric_key(protocol, key_id, cp) + data = args.get("data", b"") + hmac_value = args.get("hmac") + if hmac_value is None: + return {"error": "verify_hmac: hmac is required"} + expected = hmac.new(shared_secret, data, hashlib.sha256).digest() + valid = hmac.compare_digest(expected, hmac_value) + return {"valid": valid} + except Exception as e: + return {"error": f"verify_hmac: {e}"} + + def abort_action(self, *a, **k): pass + def acquire_certificate(self, ctx: Any, args: Dict, originator: str) -> Dict: + # store minimal certificate record for listing/discovery + record = { + "certificateBytes": args.get("type", b"") + args.get("serialNumber", b""), + "keyring": args.get("keyringForSubject"), + "verifier": b"", + "match": (args.get("type"), args.get("serialNumber"), args.get("certifier")), + "attributes": args.get("fields", {}), + } + self._certificates.append(record) + return {} + def create_action(self, ctx: Any, args: Dict, originator: str) -> Dict: + # Simplified: register an action in memory and return a signable skeleton + labels = args.get("labels") or [] + description = args.get("description", "") + outputs = args.get("outputs") or [] + # Capture inputs meta for tests to verify unlockingScriptLength estimation + inputs_meta = args.get("inputs") or [] + total_out = sum(int(o.get("satoshis", 0)) for o in outputs) + action = { + "txid": b"\x00" * 32, + "satoshis": total_out, + "status": "unprocessed", + "isOutgoing": True, + "description": description, + "labels": labels, + "version": int(args.get("version") or 0), + "lockTime": int(args.get("lockTime") or 0), + "inputs": inputs_meta, + "outputs": [ + { + "outputIndex": int(i), + "satoshis": int(o.get("satoshis", 0)), + "lockingScript": o.get("lockingScript", b""), + "spendable": True, + "outputDescription": o.get("outputDescription", ""), + "basket": o.get("basket", ""), + "tags": o.get("tags") or [], + "customInstructions": o.get("customInstructions"), + } + for i, o in enumerate(outputs) + ], + } + self._actions.append(action) + # Build a naive signable transaction bytes from inputs/outputs counts for testing + try: + from bsv.utils import Writer + from bsv.transaction import Transaction + t = Transaction() + # Populate outputs with provided lockingScript/satoshis + for o in outputs: + from bsv.transaction_output import TransactionOutput + from bsv.script.script import Script + s = Script.from_hex((o.get("lockingScript") or b"").hex()) if hasattr(Script, 'from_hex') else Script() + to = TransactionOutput(o.get("satoshis", 0), s) + t.add_output(to) + signable_tx = t.serialize() + except Exception: + signable_tx = b"\x00" + return {"signableTransaction": {"tx": signable_tx, "reference": b"ref"}} + def discover_by_attributes(self, ctx: Any, args: Dict, originator: str) -> Dict: + attrs = args.get("attributes", {}) or {} + matches = [] + for c in self._certificates: + if all(c.get("attributes", {}).get(k) == v for k, v in attrs.items()): + # Return identity certificate minimal (wrap stored bytes as base cert only) + matches.append({ + "certificateBytes": c.get("certificateBytes", b""), + "certifierInfo": {"name": "", "iconUrl": "", "description": "", "trust": 0}, + "publiclyRevealedKeyring": {}, + "decryptedFields": {}, + }) + return {"totalCertificates": len(matches), "certificates": matches} + def discover_by_identity_key(self, ctx: Any, args: Dict, originator: str) -> Dict: + # naive: no identity index, return empty + return {"totalCertificates": 0, "certificates": []} + def get_header_for_height(self, ctx: Any, args: Dict, originator: str) -> Dict: + # minimal: return empty header bytes + return {"header": b""} + def get_height(self, ctx: Any, args: Dict, originator: str) -> Dict: + return {"height": 0} + def get_network(self, ctx: Any, args: Dict, originator: str) -> Dict: + return {"network": "mocknet"} + def get_version(self, ctx: Any, args: Dict, originator: str) -> Dict: + return {"version": "0.0.0"} + def internalize_action(self, ctx: Any, args: Dict, originator: str) -> Dict: + # Mark last action as completed (mock behavior) + if self._actions: + self._actions[-1]["status"] = "completed" + return {"accepted": True} + def is_authenticated(self, ctx: Any, args: Dict, originator: str) -> Dict: + return {"authenticated": True} + def list_actions(self, ctx: Any, args: Dict, originator: str) -> Dict: + labels = args.get("labels") or [] + mode = args.get("labelQueryMode", "") + def match(act): + if not labels: + return True + act_labels = act.get("labels") or [] + if mode == "all": + return all(l in act_labels for l in labels) + # default any + return any(l in act_labels for l in labels) + actions = [a for a in self._actions if match(a)] + return {"totalActions": len(actions), "actions": actions} + def list_certificates(self, ctx: Any, args: Dict, originator: str) -> Dict: + # Minimal: return stored certificates + return {"totalCertificates": len(self._certificates), "certificates": self._certificates} + def list_outputs(self, ctx: Any, args: Dict, originator: str) -> Dict: + # Return outputs for the requested basket from the most recent action, and include a BEEF + include = (args.get("include") or "").lower() + basket = args.get("basket", "") + outputs_desc = [] + # Find the most recent action with outputs matching the basket + for action in reversed(self._actions): + outs = action.get("outputs") or [] + filtered = [o for o in outs if (not basket) or (o.get("basket") == basket)] + if filtered: + outputs_desc = filtered + break + if not outputs_desc: + # Fallback to one mock output + outputs_desc = [ + { + "outputIndex": 0, + "satoshis": 1000, + "lockingScript": b"\x51", + "spendable": True, + "outputDescription": "mock", + "basket": basket, + "tags": args.get("tags", []) or [], + "customInstructions": None, + } + ] + # Build Transaction with these outputs for BEEF inclusion; ensure locking script is the one we stored + if os.getenv("REGISTRY_DEBUG") == "1": + print("[DEBUG list_outputs] basket", basket, "outputs_desc", outputs_desc) + try: + from bsv.transaction import Transaction + from bsv.transaction_output import TransactionOutput + from bsv.script.script import Script + tx = Transaction() + for o in outputs_desc: + ls_hex = o.get("lockingScript") + if isinstance(ls_hex, str): + ls_bytes = bytes.fromhex(ls_hex) + else: + ls_bytes = ls_hex or b"\x51" + to = TransactionOutput(Script(ls_bytes), int(o.get("satoshis", 0))) + tx.add_output(to) + beef_bytes = tx.to_beef() + except Exception: + beef_bytes = b"" + # Prepare result + result_outputs = [] + for idx, o in enumerate(outputs_desc): + # ensure lockingScript hex string + ls_hex = o.get("lockingScript") + if not isinstance(ls_hex, str): + ls_hex = (ls_hex or b"\x51").hex() + + ro = { + "outputIndex": int(o.get("outputIndex", idx)), + "satoshis": int(o.get("satoshis", 0)), + "lockingScript": ls_hex, + "spendable": True, + "outputDescription": o.get("outputDescription", ""), + "basket": o.get("basket", basket), + "tags": o.get("tags") or [], + "customInstructions": o.get("customInstructions"), + "txid": "00" * 32, + } + result_outputs.append(ro) + res = {"outputs": result_outputs} + if "entire" in include or "transaction" in include: + res["BEEF"] = beef_bytes + return res + def prove_certificate(self, ctx: Any, args: Dict, originator: str) -> Dict: + return {"keyringForVerifier": {}, "verifier": args.get("verifier", b"")} + def relinquish_certificate(self, ctx: Any, args: Dict, originator: str) -> Dict: + # Remove matching certificate if present + typ = args.get("type") + serial = args.get("serialNumber") + certifier = args.get("certifier") + self._certificates = [c for c in self._certificates if not ( + c.get("match") == (typ, serial, certifier) + )] + return {} + def relinquish_output(self, ctx: Any, args: Dict, originator: str) -> Dict: + return {} + def reveal_counterparty_key_linkage(self, ctx: Any, args: Dict, originator: str) -> Dict: + """Reveal linkage information between our keys and a counterparty's key. + + The mock implementation does **not** actually compute any linkage bytes. The goal is + simply to provide enough behaviour for the unit-tests: + + 1. If `seekPermission` is truthy we call the standard `_check_permission` helper which + may raise a `PermissionError` that we surface back to the caller as an `error` dict. + 2. On success we just return an empty dict – the serializer for linkage results does + not expect any payload (it always returns an empty `bytes` string). + """ + try: + seek_permission = args.get("seekPermission") or args.get("seek_permission") + if os.getenv("BSV_DEBUG", "0") == "1": + print(f"[DEBUG WalletImpl.reveal_counterparty_key_linkage] originator={originator} seek_permission={seek_permission} args={args}") + + if seek_permission: + # Ask the user (or callback) for permission + self._check_permission("鍵リンク開示 (counterparty)") + + # Real implementation would compute and return linkage data here. For test purposes + # we return an empty dict which the serializer converts to an empty payload. + return {} + except Exception as e: + return {"error": f"reveal_counterparty_key_linkage: {e}"} + + def reveal_specific_key_linkage(self, ctx: Any, args: Dict, originator: str) -> Dict: + """Reveal linkage information for a *specific* derived key. + + Mimics `reveal_counterparty_key_linkage` with the addition of protocol/key parameters + but, for this mock implementation, does not actually use them. + """ + try: + seek_permission = args.get("seekPermission") or args.get("seek_permission") + if os.getenv("BSV_DEBUG", "0") == "1": + print(f"[DEBUG WalletImpl.reveal_specific_key_linkage] originator={originator} seek_permission={seek_permission} args={args}") + + if seek_permission: + self._check_permission("鍵リンク開示 (specific)") + + return {} + except Exception as e: + return {"error": f"reveal_specific_key_linkage: {e}"} + + def sign_action(self, ctx: Any, args: Dict, originator: str) -> Dict: + # Return a pseudo-signed transaction and txid + ref = (args or {}).get("reference") or b"" + spends = (args or {}).get("spends") or {} + body = b"signed" + ref + b";" + b";".join((spends.get(i, {}).get("unlockingScript", b"") for i in sorted(spends))) + fake_txid = hashlib.sha256(body).digest()[::-1] + return {"tx": body, "txid": fake_txid} + def wait_for_authentication(self, ctx: Any, args: Dict, originator: str) -> Dict: + return {"authenticated": True} diff --git a/bsv/wallet/wallet_interface.py b/bsv/wallet/wallet_interface.py new file mode 100644 index 0000000..ad9c8fd --- /dev/null +++ b/bsv/wallet/wallet_interface.py @@ -0,0 +1,122 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict + +class WalletInterface(ABC): + """ + Python port of Go's wallet.Interface (core wallet operations for transaction creation, signing, querying, and cryptographic operations). + All methods raise NotImplementedError by default. + """ + + # --- KeyOperations --- + @abstractmethod + def get_public_key(self, ctx: Any, args: Dict, originator: str) -> Any: + raise NotImplementedError + + @abstractmethod + def encrypt(self, ctx: Any, args: Dict, originator: str) -> Any: + raise NotImplementedError + + @abstractmethod + def decrypt(self, ctx: Any, args: Dict, originator: str) -> Any: + raise NotImplementedError + + @abstractmethod + def create_hmac(self, ctx: Any, args: Dict, originator: str) -> Any: + raise NotImplementedError + + @abstractmethod + def verify_hmac(self, ctx: Any, args: Dict, originator: str) -> Any: + raise NotImplementedError + + @abstractmethod + def create_signature(self, ctx: Any, args: Dict, originator: str) -> Any: + raise NotImplementedError + + @abstractmethod + def verify_signature(self, ctx: Any, args: Dict, originator: str) -> Any: + raise NotImplementedError + + # --- Core wallet operations --- + @abstractmethod + def create_action(self, ctx: Any, args: Dict, originator: str) -> Any: + raise NotImplementedError + + @abstractmethod + def sign_action(self, ctx: Any, args: Dict, originator: str) -> Any: + raise NotImplementedError + + @abstractmethod + def abort_action(self, ctx: Any, args: Dict, originator: str) -> Any: + raise NotImplementedError + + @abstractmethod + def list_actions(self, ctx: Any, args: Dict, originator: str) -> Any: + raise NotImplementedError + + @abstractmethod + def internalize_action(self, ctx: Any, args: Dict, originator: str) -> Any: + raise NotImplementedError + + @abstractmethod + def list_outputs(self, ctx: Any, args: Dict, originator: str) -> Any: + raise NotImplementedError + + @abstractmethod + def relinquish_output(self, ctx: Any, args: Dict, originator: str) -> Any: + raise NotImplementedError + + @abstractmethod + def reveal_counterparty_key_linkage(self, ctx: Any, args: Dict, originator: str) -> Any: + raise NotImplementedError + + @abstractmethod + def reveal_specific_key_linkage(self, ctx: Any, args: Dict, originator: str) -> Any: + raise NotImplementedError + + @abstractmethod + def acquire_certificate(self, ctx: Any, args: Dict, originator: str) -> Any: + raise NotImplementedError + + @abstractmethod + def list_certificates(self, ctx: Any, args: Dict, originator: str) -> Any: + raise NotImplementedError + + @abstractmethod + def prove_certificate(self, ctx: Any, args: Dict, originator: str) -> Any: + raise NotImplementedError + + @abstractmethod + def relinquish_certificate(self, ctx: Any, args: Dict, originator: str) -> Any: + raise NotImplementedError + + @abstractmethod + def discover_by_identity_key(self, ctx: Any, args: Dict, originator: str) -> Any: + raise NotImplementedError + + @abstractmethod + def discover_by_attributes(self, ctx: Any, args: Dict, originator: str) -> Any: + raise NotImplementedError + + @abstractmethod + def is_authenticated(self, ctx: Any, args: Any, originator: str) -> Any: + raise NotImplementedError + + @abstractmethod + def wait_for_authentication(self, ctx: Any, args: Any, originator: str) -> Any: + raise NotImplementedError + + @abstractmethod + def get_height(self, ctx: Any, args: Any, originator: str) -> Any: + raise NotImplementedError + + @abstractmethod + def get_header_for_height(self, ctx: Any, args: Dict, originator: str) -> Any: + raise NotImplementedError + + @abstractmethod + def get_network(self, ctx: Any, args: Any, originator: str) -> Any: + raise NotImplementedError + + @abstractmethod + def get_version(self, ctx: Any, args: Any, originator: str) -> Any: + raise NotImplementedError \ No newline at end of file From f554350019b1033531883fd1c0a64db1bb1b9d78 Mon Sep 17 00:00:00 2001 From: defiant1708 Date: Wed, 20 Aug 2025 22:52:24 +0900 Subject: [PATCH 5/5] feat(auth): concurrent sessions, handshake, cert request/response; wallet integration and callbacks\n\n- Closes #55, #54, #53, #52, #51, #50, #49, #48\n- Implement session manager + helpers and re-export PeerSession\n- Handshake (nonce-based), general message sign/verify, cert req/resp\n- Integrate wallet interface for HMAC/sign/verify + encryption hooks\n- Callback registration APIs and safer invocation (snapshot)\n- Reduce cognitive complexity and add defensive checks\n- Best-effort stop() and secure_send() delegate --- bsv/auth/master_certificate.py | 186 ++--- bsv/auth/peer.py | 1175 +++++++++++++++++++------------- bsv/auth/utils.py | 188 ++--- bsv/wallet/wallet_impl.py | 205 +++--- 4 files changed, 991 insertions(+), 763 deletions(-) diff --git a/bsv/auth/master_certificate.py b/bsv/auth/master_certificate.py index dfca050..fffb353 100644 --- a/bsv/auth/master_certificate.py +++ b/bsv/auth/master_certificate.py @@ -7,7 +7,6 @@ Base64String = str CertificateFieldNameUnder50Bytes = str - class MasterCertificate(Certificate): def __init__( self, @@ -58,93 +57,67 @@ def create_certificate_fields(creator_wallet: Any, certifier_or_subject: Any, fi return {'certificateFields': certificate_fields, 'masterKeyring': master_keyring} @staticmethod - def issue_certificate_for_subject( - certifier_wallet: Any, - subject: Any, - fields: Dict[CertificateFieldNameUnder50Bytes, str], - certificate_type: str, - get_revocation_outpoint: Optional[Callable[[str], Any]] = None, - serial_number: Optional[str] = None - ) -> 'MasterCertificate': - if serial_number is not None: - final_serial_number = serial_number - else: - final_serial_number = base64.b64encode(os.urandom(32)).decode('utf-8') - field_result = MasterCertificate.create_certificate_fields(certifier_wallet, subject, fields) - certificate_fields = field_result['certificateFields'] - master_keyring = field_result['masterKeyring'] - if get_revocation_outpoint is not None: - revocation_outpoint = get_revocation_outpoint(final_serial_number) - else: - revocation_outpoint = None - # 1) Certifier public key resolution via wallet interface if available - certifier_pubkey = None + def _resolve_public_key(wallet: Any, fallback: Any = None) -> Any: + """ + Resolve the public key from the wallet. If it fails, return the fallback. + """ + from bsv.keys import PublicKey + pubkey = None try: - # Prefer WalletInterface.get_public_key with identityKey=True get_pk_args = {"identityKey": True} - # Some wallet interfaces accept seekPermission; keep it False by default - res = certifier_wallet.get_public_key(None, get_pk_args, "auth-master-cert") + res = wallet.get_public_key(None, get_pk_args, "auth-master-cert") if isinstance(res, dict): pk_bytes_or_hex = res.get("publicKey") if pk_bytes_or_hex: - from bsv.keys import PublicKey - certifier_pubkey = PublicKey(pk_bytes_or_hex) + pubkey = PublicKey(pk_bytes_or_hex) except Exception: - certifier_pubkey = None - - # Fallbacks: try common attributes exposed by simple wallets - if certifier_pubkey is None: + pubkey = None + if pubkey is None: try: - # e.g. WalletImpl exposes .public_key - certifier_pubkey = getattr(certifier_wallet, "public_key", None) + pubkey = getattr(wallet, "public_key", None) except Exception: - certifier_pubkey = None - if certifier_pubkey is None: - raise ValueError("Unable to resolve certifier public key from wallet") + pubkey = None + if pubkey is None and fallback is not None: + pubkey = fallback + return pubkey - # 1b) Resolve subject public key + @staticmethod + def _resolve_subject_public_key(subject: Any, certifier_pubkey: Any) -> Any: from bsv.keys import PublicKey - subject_pubkey = None - # Dict-like counterparty: {"type": , "counterparty": } - if isinstance(subject, dict): - try: - stype = subject.get("type") - if stype in (0, 2): # self / anyone - subject_pubkey = certifier_pubkey - else: - cp = subject.get("counterparty") - if cp is not None: - subject_pubkey = PublicKey(cp) - except Exception: - subject_pubkey = None - # Already a PublicKey - if subject_pubkey is None and isinstance(subject, PublicKey): - subject_pubkey = subject - # Bytes/hex string - if subject_pubkey is None and isinstance(subject, (bytes, bytearray, str)): + + # If already a PublicKey instance + if isinstance(subject, PublicKey): + return subject + + # If provided as bytes/bytearray/hex string + if isinstance(subject, (bytes, bytearray, str)): try: - subject_pubkey = PublicKey(subject) + return PublicKey(subject) except Exception: - subject_pubkey = None - # Fallbacks: treat as self if still unresolved - if subject_pubkey is None: - subject_pubkey = certifier_pubkey + return certifier_pubkey - # 2) Construct unsigned MasterCertificate - cert = MasterCertificate( - certificate_type, - final_serial_number, - subject_pubkey, - certifier_pubkey, - revocation_outpoint, - certificate_fields, - signature=None, - master_keyring=master_keyring, - ) + # If provided as a dict descriptor + if isinstance(subject, dict): + stype = subject.get("type") + if stype in (0, 2): # self / anyone + return certifier_pubkey + cp = subject.get("counterparty") + if cp is not None: + try: + return PublicKey(cp) + except Exception: + pass + return certifier_pubkey + + # Fallback + return certifier_pubkey - # 3) Sign using wallet interface if available; fallback to direct private key + @staticmethod + def _sign_certificate(cert: 'MasterCertificate', certifier_wallet: Any, certificate_type: str, final_serial_number: str) -> Optional[bytes]: + """ + Attach a signature to the certificate. Prefer the wallet interface; otherwise use the private_key attribute. + """ try: - # Use wallet wire compatible signing first data_to_sign = cert.to_binary(include_signature=False) sig_args = { 'encryption_args': { @@ -153,7 +126,6 @@ def issue_certificate_for_subject( 'protocol': 'certificate signature', }, 'key_id': f"{certificate_type} {final_serial_number}", - # Anyone 'counterparty': {'type': 2}, }, 'data': data_to_sign, @@ -164,16 +136,49 @@ def issue_certificate_for_subject( except Exception: sig_res = None if isinstance(sig_res, dict) and sig_res.get('signature'): - cert.signature = sig_res['signature'] + return sig_res['signature'] else: - # Fallback: direct private key if exposed priv = getattr(certifier_wallet, "private_key", None) if priv is not None: + # sign mutates the certificate; ensure we return bytes for callers cert.sign(priv) + return cert.signature except Exception: - # Leave unsigned; caller may sign later using their own mechanism pass + return None + + @staticmethod + def issue_certificate_for_subject( + certifier_wallet: Any, + subject: Any, + fields: Dict[CertificateFieldNameUnder50Bytes, str], + certificate_type: str, + get_revocation_outpoint: Optional[Callable[[str], Any]] = None, + serial_number: Optional[str] = None + ) -> 'MasterCertificate': + final_serial_number = serial_number or base64.b64encode(os.urandom(32)).decode('utf-8') + field_result = MasterCertificate.create_certificate_fields(certifier_wallet, subject, fields) + certificate_fields = field_result['certificateFields'] + master_keyring = field_result['masterKeyring'] + revocation_outpoint = get_revocation_outpoint(final_serial_number) if get_revocation_outpoint else None + + certifier_pubkey = MasterCertificate._resolve_public_key(certifier_wallet) + if certifier_pubkey is None: + raise ValueError("Unable to resolve certifier public key from wallet") + subject_pubkey = MasterCertificate._resolve_subject_public_key(subject, certifier_pubkey) + + cert = MasterCertificate( + certificate_type, + final_serial_number, + subject_pubkey, + certifier_pubkey, + revocation_outpoint, + certificate_fields, + signature=None, + master_keyring=master_keyring, + ) + cert.signature = MasterCertificate._sign_certificate(cert, certifier_wallet, certificate_type, final_serial_number) return cert @staticmethod @@ -187,8 +192,9 @@ def decrypt_field( privileged_reason: Optional[str] = None ) -> Dict[str, Any]: """ - master_keyringからfield_nameの対称鍵をbase64デコード→wallet.decryptで復号→encrypted_field_valueをbase64デコード→対称鍵でAES-GCM復号 - 戻り値: { 'fieldRevelationKey': bytes, 'decryptedFieldValue': str } + Base64-decode the symmetric key for the given field_name from the master_keyring, decrypt it via wallet.decrypt, + base64-decode the encrypted_field_value, then decrypt it with the symmetric key using AES-GCM. + Returns: { 'fieldRevelationKey': bytes, 'decryptedFieldValue': str } """ if field_name not in master_keyring: raise ValueError(f"Field '{field_name}' not found in master_keyring.") @@ -205,10 +211,10 @@ def decrypt_field( }, "ciphertext": encrypted_key_bytes, } - # 対称鍵の復号(wallet.decrypt) + # Decrypt the symmetric key (wallet.decrypt) decrypt_result = subject_or_certifier_wallet.decrypt(None, decrypt_args) if not decrypt_result or 'plaintext' not in decrypt_result: - raise NotImplementedError("wallet.decryptの実装が必要です") + raise NotImplementedError("wallet.decrypt implementation is required") field_revelation_key = decrypt_result['plaintext'] encrypted_field_bytes = base64.b64decode(encrypted_field_value) decrypted_field_bytes = EncryptedMessage.aes_gcm_decrypt(field_revelation_key, encrypted_field_bytes) @@ -227,8 +233,8 @@ def decrypt_fields( privileged_reason: Optional[str] = None ) -> Dict[CertificateFieldNameUnder50Bytes, str]: """ - fieldsの各フィールドに対してdecrypt_fieldを呼び出し、結果を集約 - 戻り値: { field_name: decrypted_value } + Invoke decrypt_field for each entry in fields and aggregate the results. + Returns: { field_name: decrypted_value } """ decrypted_fields: Dict[CertificateFieldNameUnder50Bytes, str] = {} for field_name, encrypted_field_value in fields.items(): @@ -257,17 +263,17 @@ def create_keyring_for_verifier( privileged_reason: Optional[str] = None ) -> Dict[CertificateFieldNameUnder50Bytes, Base64String]: """ - fields_to_revealで指定された各フィールドについて: - 1. master_keyringから対称鍵を復号(decrypt_fieldを利用) - 2. subject_wallet.encryptでverifier用に再暗号化(serial_numberをkey_idに含める) - 3. 結果をBase64でkeyringに格納 - 返り値: { field_name: encrypted_key_for_verifier } + For each field specified in fields_to_reveal: + 1. Decrypt the symmetric key from the master_keyring (using decrypt_field) + 2. Re-encrypt it with subject_wallet.encrypt for the verifier (include serial_number in key_id) + 3. Store the result in the keyring as Base64 + Returns: { field_name: encrypted_key_for_verifier } """ keyring_for_verifier: Dict[CertificateFieldNameUnder50Bytes, Base64String] = {} for field_name in fields_to_reveal: if field_name not in fields: raise ValueError(f"Field '{field_name}' not found in certificate fields.") - # 1. master_keyringから対称鍵を復号 + # 1. Decrypt the symmetric key from the master_keyring decrypt_result = MasterCertificate.decrypt_field( subject_wallet, master_keyring, @@ -278,7 +284,7 @@ def create_keyring_for_verifier( privileged_reason ) field_revelation_key = decrypt_result['fieldRevelationKey'] - # 2. subject_wallet.encryptでverifier用に再暗号化 + # 2. Re-encrypt for the verifier with subject_wallet.encrypt protocol_id, key_id = get_certificate_encryption_details(field_name, serial_number) encrypt_args = { "encryption_args": { diff --git a/bsv/auth/peer.py b/bsv/auth/peer.py index 81ed345..5e9a3f4 100644 --- a/bsv/auth/peer.py +++ b/bsv/auth/peer.py @@ -3,8 +3,9 @@ import json import base64 -# from .session_manager import SessionManager from .transports.transport import Transport +# Re-export PeerSession for compatibility with session_manager typing/tests +from .peer_session import PeerSession class PeerOptions: @@ -54,16 +55,27 @@ def __init__(self, cfg: PeerOptions): if cfg.auto_persist_last_session is None or cfg.auto_persist_last_session: self.auto_persist_last_session = True if self.certificates_to_request is None: - # TODO: Replace with actual RequestedCertificateSet - self.certificates_to_request = { - 'certifiers': [], - 'certificate_types': {} - } + try: + from .requested_certificate_set import RequestedCertificateSet, RequestedCertificateTypeIDAndFieldList + self.certificates_to_request = RequestedCertificateSet( + certifiers=[], + certificate_types=RequestedCertificateTypeIDAndFieldList(), + ) + except Exception: + # Fallback to a minimal dict structure if imports are unavailable + self.certificates_to_request = { + 'certifiers': [], + 'certificate_types': {} + } # Start the peer (register handlers, etc.) try: self.start() except Exception as e: self.logger.warning(f"Failed to start peer: {e}") + self.FAIL_TO_GET_IDENTIFY_KEY = "failed to get identity key" + self.AUTH_MESSAGE_SIGNATURE = "auth message signature" + self.SESSION_NOT_FOUND = "Session not found" + self.FAILED_TO_GET_AUTHENTICATED_SESSION = "failed to get authenticated session" def start(self): """ @@ -83,206 +95,204 @@ def on_data(ctx, message): print("[Peer DEBUG] transport handler registration ok") # --- Canonicalization helpers for signing/verification --- + def _rcs_hex_certifiers(self, raw_list: Any) -> list: + certs: list = [] + for pk in raw_list or []: + try: + if hasattr(pk, 'hex') and callable(getattr(pk, 'hex')): + certs.append(pk.hex()) + elif isinstance(pk, (bytes, bytearray)): + certs.append(bytes(pk).hex()) + else: + certs.append(str(pk)) + except Exception: + certs.append(str(pk)) + return certs + + def _rcs_key_to_b64(self, key: Any) -> Optional[str]: + import base64 as _b64 + if isinstance(key, (bytes, bytearray)): + b = bytes(key) + return _b64.b64encode(b).decode('ascii') if len(b) == 32 else None + ks = str(key) + try: + dec = _b64.b64decode(ks) + if len(dec) == 32: + return _b64.b64encode(dec).decode('ascii') + except Exception: + pass + try: + b = bytes.fromhex(ks) + if len(b) == 32: + return _b64.b64encode(b).decode('ascii') + except Exception: + pass + return None + + def _rcs_types_dict_from_requested(self, req: Any) -> dict: + if isinstance(req, dict): + return ( + req.get('certificate_types') + or req.get('certificateTypes') + or req.get('types') + or {} + ) + return {} + + def _rcs_from_object(self, requested_obj: Any) -> tuple[list, dict]: + certifiers = self._rcs_hex_certifiers(getattr(requested_obj, 'certifiers', []) or []) + mapping = getattr(getattr(requested_obj, 'certificate_types', None), 'mapping', {}) or {} + types_b64: dict = {} + for k, v in mapping.items(): + k_b64 = self._rcs_key_to_b64(k) + if k_b64 is None: + continue + types_b64[k_b64] = list(v or []) + return certifiers, types_b64 + + def _rcs_from_dict(self, requested_dict: dict) -> tuple[list, dict]: + certifiers = self._rcs_hex_certifiers(requested_dict.get('certifiers', [])) + types_b64: dict = {} + for k, v in self._rcs_types_dict_from_requested(requested_dict).items(): + k_b64 = self._rcs_key_to_b64(k) + if k_b64 is None: + continue + types_b64[k_b64] = list(v or []) + return certifiers, types_b64 + def _canonicalize_requested_certificates(self, requested: Any) -> dict: try: from .requested_certificate_set import RequestedCertificateSet except Exception: RequestedCertificateSet = None # type: ignore - result: dict = {"certifiers": [], "certificateTypes": {}} + if requested is None: - return result + return {"certifiers": [], "certificateTypes": {}} + try: - # Normalize certifiers - certifiers: list = [] + certifiers: list + types_b64: dict + if RequestedCertificateSet is not None and isinstance(requested, RequestedCertificateSet): - for pk in requested.certifiers: - try: - certifiers.append(pk.hex()) - except Exception: - certifiers.append(str(pk)) - mapping = getattr(requested.certificate_types, 'mapping', {}) or {} - for k, v in mapping.items(): - try: - import base64 as _b64 - k_b64 = _b64.b64encode(k).decode('ascii') if isinstance(k, (bytes, bytearray)) else str(k) - except Exception: - k_b64 = str(k) - result["certificateTypes"][k_b64] = sorted(list(v or [])) + certifiers, types_b64 = self._rcs_from_object(requested) elif isinstance(requested, dict): - # Expect 'certifiers' as list of hex strings or objects with hex - for pk in requested.get('certifiers', []): - try: - certifiers.append(pk.hex()) - except Exception: - certifiers.append(str(pk)) - types_dict = ( - requested.get('certificate_types') - or requested.get('certificateTypes') - or requested.get('types') - or {} - ) - # Canonicalize keys to base64 for deterministic cross-language signatures - import base64 as _b64 - for k, v in types_dict.items(): - k_b64: str - if isinstance(k, (bytes, bytearray)): - if len(k) != 32: - continue - k_b64 = _b64.b64encode(bytes(k)).decode('ascii') - else: - ks = str(k) - try: - # If already base64 of length 32 bytes when decoded, keep as-is - dec = _b64.b64decode(ks) - if len(dec) == 32: - k_b64 = _b64.b64encode(dec).decode('ascii') - else: - # Try hex - b = bytes.fromhex(ks) - if len(b) != 32: - continue - k_b64 = _b64.b64encode(b).decode('ascii') - except Exception: - try: - b = bytes.fromhex(ks) - if len(b) != 32: - continue - k_b64 = _b64.b64encode(b).decode('ascii') - except Exception: - # Unknown format; skip - continue - result["certificateTypes"][k_b64] = sorted(list(v or [])) - result["certifiers"] = sorted(certifiers) + certifiers, types_b64 = self._rcs_from_dict(requested) + else: + certifiers, types_b64 = [], {} + + # Sort outputs deterministically + sorted_types = {k: sorted(list(v or [])) for k, v in types_b64.items()} + return {"certifiers": sorted(certifiers), "certificateTypes": sorted_types} except Exception: - # Fallback to string-dump to avoid raising return {"certifiers": [], "certificateTypes": {}} - return result - - def _canonicalize_certificates_payload(self, certs: Any) -> list: - import base64 as _b64 - canonical: list = [] - if not certs: - return canonical - def _to_b64_32(value: Any) -> Optional[str]: - if value is None: - return None - # If already bytes, expect 32 bytes - if isinstance(value, (bytes, bytearray)): - b = bytes(value) + # --- Helpers for certificate payload canonicalization --- + def _b64_32(self, value: Any) -> Optional[str]: + if value is None: + return None + if isinstance(value, (bytes, bytearray)): + b = bytes(value) + return base64.b64encode(b).decode('ascii') if len(b) == 32 else None + if isinstance(value, str): + s = value + try: + dec = base64.b64decode(s) + if len(dec) == 32: + return base64.b64encode(dec).decode('ascii') + except Exception: + pass + try: + b = bytes.fromhex(s) if len(b) == 32: - return _b64.b64encode(b).decode('ascii') - return None - # If has .encode (string) - if isinstance(value, str): - s = value - # Try base64 first - try: - dec = _b64.b64decode(s) - if len(dec) == 32: - return _b64.b64encode(dec).decode('ascii') - except Exception: - pass - # Try hex - try: - b = bytes.fromhex(s) - if len(b) == 32: - return _b64.b64encode(b).decode('ascii') - except Exception: - pass - return None + return base64.b64encode(b).decode('ascii') + except Exception: + pass return None + return None - def _pubkey_to_hex(value: Any) -> Optional[str]: - if value is None: + def _pubkey_to_hex(self, value: Any) -> Optional[str]: + if value is None: + return None + if hasattr(value, 'hex') and callable(getattr(value, 'hex')): + try: + return value.hex() + except Exception: return None - # PublicKey object with hex() method - if hasattr(value, 'hex') and callable(getattr(value, 'hex')): - try: - return value.hex() - except Exception: - pass - # bytes -> hex - if isinstance(value, (bytes, bytearray)): - return bytes(value).hex() - # string: try base64(33) to hex, else assume already hex - if isinstance(value, str): - s = value - try: - dec = _b64.b64decode(s) - # Compressed pubkey typically 33 bytes - if len(dec) in (33, 65): - return dec.hex() - except Exception: - pass - # Heuristic: if looks like hex - try: - _ = bytes.fromhex(s) - return s.lower() - except Exception: - pass + if isinstance(value, (bytes, bytearray)): + return bytes(value).hex() + if isinstance(value, str): + s = value + try: + dec = base64.b64decode(s) + if len(dec) in (33, 65): + return dec.hex() + except Exception: + pass + try: + _ = bytes.fromhex(s) + return s.lower() + except Exception: return s - return str(value) + return str(value) + def _normalize_revocation_outpoint(self, rev: Any) -> Optional[dict]: + if isinstance(rev, dict): + return {"txid": rev.get('txid'), "index": rev.get('index')} + if rev is not None and hasattr(rev, 'txid') and hasattr(rev, 'index'): + return {"txid": getattr(rev, 'txid', None), "index": getattr(rev, 'index', None)} + return None + + def _get_base_keyring_signature(self, entry: Any): + if isinstance(entry, dict): + return entry.get('certificate', entry), (entry.get('keyring', {}) or {}), entry.get('signature') + return ( + getattr(entry, 'certificate', entry), + getattr(entry, 'keyring', {}) or {}, + getattr(entry, 'signature', None), + ) + + def _extract_base_fields(self, base: Any): + if isinstance(base, dict): + return ( + base.get('type'), + base.get('serialNumber') or base.get('serial_number'), + base.get('subject'), + base.get('certifier'), + base.get('revocationOutpoint') or base.get('revocation_outpoint'), + base.get('fields', {}) or {}, + ) + return ( + getattr(base, 'type', None), + getattr(base, 'serial_number', None), + getattr(base, 'subject', None), + getattr(base, 'certifier', None), + getattr(base, 'revocation_outpoint', None), + getattr(base, 'fields', {}) or {}, + ) + + def _canonicalize_cert_entry(self, entry: Any) -> dict: + base, keyring, signature = self._get_base_keyring_signature(entry) + cert_type_raw, serial_raw, subject_raw, certifier_raw, rev, fields = self._extract_base_fields(base) + return { + "type": self._b64_32(cert_type_raw) or cert_type_raw, + "serialNumber": self._b64_32(serial_raw) or serial_raw, + "subject": self._pubkey_to_hex(subject_raw), + "certifier": self._pubkey_to_hex(certifier_raw), + "revocationOutpoint": self._normalize_revocation_outpoint(rev), + "fields": fields, + "keyring": keyring, + "signature": (base64.b64encode(signature).decode('ascii') if isinstance(signature, (bytes, bytearray)) else signature), + } + + def _canonicalize_certificates_payload(self, certs: Any) -> list: + canonical: list = [] + if not certs: + return canonical for c in certs: try: - # Support object or dict inputs, and nested {"certificate": ...} - base = None - keyring = {} - signature = None - if isinstance(c, dict): - base = c.get('certificate', c) - keyring = c.get('keyring', {}) or {} - signature = c.get('signature') - else: - base = getattr(c, 'certificate', c) - keyring = getattr(c, 'keyring', {}) or {} - signature = getattr(c, 'signature', None) - - # Extract fields from base certificate - if isinstance(base, dict): - cert_type_raw = base.get('type') - serial_raw = base.get('serialNumber') or base.get('serial_number') - subject_raw = base.get('subject') - certifier_raw = base.get('certifier') - rev = base.get('revocationOutpoint') or base.get('revocation_outpoint') - fields = base.get('fields', {}) or {} - else: - cert_type_raw = getattr(base, 'type', None) - serial_raw = getattr(base, 'serial_number', None) - subject_raw = getattr(base, 'subject', None) - certifier_raw = getattr(base, 'certifier', None) - rev = getattr(base, 'revocation_outpoint', None) - fields = getattr(base, 'fields', {}) or {} - - # Normalize primitives - cert_type_b64 = _to_b64_32(cert_type_raw) or cert_type_raw - serial_b64 = _to_b64_32(serial_raw) or serial_raw - subject_hex = _pubkey_to_hex(subject_raw) - certifier_hex = _pubkey_to_hex(certifier_raw) - rev_dict = None - if isinstance(rev, dict): - rev_dict = {"txid": rev.get('txid'), "index": rev.get('index')} - elif rev is not None and hasattr(rev, 'txid') and hasattr(rev, 'index'): - rev_dict = {"txid": getattr(rev, 'txid', None), "index": getattr(rev, 'index', None)} - sig_b64 = _b64.b64encode(signature).decode('ascii') if isinstance(signature, (bytes, bytearray)) else signature - - # Deterministic field order ensured by JSON sort_keys on serialization, but field list order stable - canonical.append({ - "type": cert_type_b64, - "serialNumber": serial_b64, - "subject": subject_hex, - "certifier": certifier_hex, - "revocationOutpoint": rev_dict, - "fields": fields, - "keyring": keyring, - "signature": sig_b64, - }) + canonical.append(self._canonicalize_cert_entry(c)) except Exception: - # Best effort: stringify canonical.append(str(c)) - - # Sort deterministically by (type, serialNumber) try: canonical.sort(key=lambda x: (x.get('type', '') or '', x.get('serialNumber', '') or '')) except Exception: @@ -325,10 +335,48 @@ def handle_initial_request(self, ctx: Any, message: Any, sender_public_key: Any) initial_nonce = getattr(message, 'initial_nonce', None) if not initial_nonce: return Exception("Invalid nonce") - import os, base64, time - our_nonce = base64.b64encode(os.urandom(32)).decode('ascii') + + # 1) Generate our session nonce + our_nonce = self._generate_session_nonce(ctx) if self._debug: print(f"[Peer DEBUG] handle_initial_request: our_nonce={our_nonce}, peer_nonce={initial_nonce}") + + # 2) Create and store session (auth status may be downgraded if we plan to request certs) + session = self._create_session_for_initial(sender_public_key, initial_nonce, our_nonce) + if self._debug: + print(f"[Peer DEBUG] handle_initial_request: session added, nonce={session.session_nonce}") + + # 3) Get our identity key + identity_key_result = self.wallet.get_public_key(ctx, {'identityKey': True}, "auth-peer") + if identity_key_result is None or not hasattr(identity_key_result, 'public_key'): + return Exception(self.FAIL_TO_GET_IDENTIFY_KEY) + + # 4) Acquire any requested certificates from the peer's initial request + certs = [] + requested_certs = getattr(message, 'requested_certificates', None) + if requested_certs is not None: + certs = self._acquire_requested_certs_for_initial(ctx, requested_certs, identity_key_result) + + # 5) Build initial response and sign it + response_err = self._send_initial_response(ctx, message, identity_key_result, initial_nonce, session, certs) + if response_err is not None: + return response_err + + if self._debug: + print("[Peer DEBUG] handle_initial_request: response sent") + return None + + def _generate_session_nonce(self, ctx: Any) -> str: + import base64 + try: + from .utils import create_nonce + return create_nonce(self.wallet, {'type': 1}, ctx) + except Exception: + import os + return base64.b64encode(os.urandom(32)).decode('ascii') + + def _create_session_for_initial(self, sender_public_key: Any, initial_nonce: str, our_nonce: str): + import time from .peer_session import PeerSession session = PeerSession( is_authenticated=True, @@ -337,67 +385,65 @@ def handle_initial_request(self, ctx: Any, message: Any, sender_public_key: Any) peer_identity_key=sender_public_key, last_update=int(time.time() * 1000) ) + # If we plan to request certificates, mark unauthenticated until received req_certs = getattr(self, 'certificates_to_request', None) if req_certs is not None and hasattr(req_certs, 'certificate_types') and len(req_certs.certificate_types) > 0: session.is_authenticated = False self.session_manager.add_session(session) - if self._debug: - print(f"[Peer DEBUG] handle_initial_request: session added, nonce={session.session_nonce}") - identity_key_result = self.wallet.get_public_key(ctx, {'identityKey': True}, "auth-peer") - if identity_key_result is None or not hasattr(identity_key_result, 'public_key'): - return Exception("failed to get identity key") - certs = [] - requested_certs = getattr(message, 'requested_certificates', None) - if requested_certs is not None: + return session + + def _acquire_requested_certs_for_initial(self, ctx: Any, requested_certs: Any, identity_key_result: Any) -> list: + import base64 + certs: list = [] + try: from .verifiable_certificate import VerifiableCertificate from .certificate import Certificate - from .requested_certificate_set import RequestedCertificateSet - try: - # Obtain from certificate DB or wallet - for cert_type, fields in requested_certs.certificate_types.items(): - args = { - 'cert_type': base64.b64encode(cert_type).decode(), - 'fields': fields, - 'subject': identity_key_result.public_key.hex(), - 'certifiers': [pk.hex() for pk in requested_certs.certifiers], - } - # Acquire certificate from wallet (use acquire_certificate or list_certificates as needed) - cert_result = self.wallet.acquire_certificate(ctx, args, "auth-peer") - # If the result is a list, wrap all, otherwise just one - if isinstance(cert_result, list): - for cert in cert_result: - if isinstance(cert, Certificate): - certs.append(VerifiableCertificate(cert)) - elif isinstance(cert_result, Certificate): - certs.append(VerifiableCertificate(cert_result)) - except Exception as e: - self.logger.warning(f"Failed to acquire certificates: {e}") + # Obtain from certificate DB or wallet + for cert_type, fields in getattr(requested_certs, 'certificate_types', {} ).items(): + args = { + 'cert_type': base64.b64encode(cert_type).decode(), + 'fields': fields, + 'subject': identity_key_result.public_key.hex(), + 'certifiers': [pk.hex() for pk in getattr(requested_certs, 'certifiers', [])], + } + cert_result = self.wallet.acquire_certificate(ctx, args, "auth-peer") + if isinstance(cert_result, list): + for cert in cert_result: + if isinstance(cert, Certificate): + certs.append(VerifiableCertificate(cert)) + elif isinstance(cert_result, Certificate): + certs.append(VerifiableCertificate(cert_result)) + except Exception as e: + self.logger.warning(f"Failed to acquire certificates: {e}") + return certs + + def _send_initial_response(self, ctx: Any, message: Any, identity_key_result: Any, initial_nonce: str, session: Any, certs: list) -> Optional[Exception]: + import base64 from .auth_message import AuthMessage response = AuthMessage( version="0.1", message_type="initialResponse", identity_key=identity_key_result.public_key, - nonce=our_nonce, + nonce=session.session_nonce, your_nonce=initial_nonce, initial_nonce=session.session_nonce, certificates=certs ) try: - initial_nonce_bytes = base64.b64decode(initial_nonce) - session_nonce_bytes = base64.b64decode(session.session_nonce) + sig_data = self._compute_initial_sig_data(initial_nonce, session.session_nonce) except Exception as e: return Exception(f"failed to decode nonce: {e}") - sig_data = initial_nonce_bytes + session_nonce_bytes + sig_result = self.wallet.create_signature(ctx, { 'encryption_args': { 'protocol_id': { 'securityLevel': 2, - 'protocol': "auth message signature" + 'protocol': self.AUTH_MESSAGE_SIGNATURE }, 'key_id': f"{initial_nonce} {session.session_nonce}", 'counterparty': { 'type': 3, - 'counterparty': message.identity_key if hasattr(message, 'identity_key') else None + 'counterparty': getattr(message, 'identity_key', None) } }, 'data': sig_data @@ -408,10 +454,143 @@ def handle_initial_request(self, ctx: Any, message: Any, sender_public_key: Any) err = self.transport.send(ctx, response) if err is not None: return Exception(f"failed to send initial response: {err}") - if self._debug: - print("[Peer DEBUG] handle_initial_request: response sent") return None + def _compute_initial_sig_data(self, initial_nonce: str, session_nonce: str) -> bytes: + import base64 + initial_nonce_bytes = base64.b64decode(initial_nonce) + session_nonce_bytes = base64.b64decode(session_nonce) + return initial_nonce_bytes + session_nonce_bytes + + # --- Helpers for certificate validation --- + def _is_rcs_like(self, obj: Any) -> bool: + return hasattr(obj, 'certifiers') and hasattr(obj, 'certificate_types') + + def _extract_certifiers_from_req(self, req: Any) -> list: + if self._is_rcs_like(req): + return list(getattr(req, 'certifiers', []) or []) + if isinstance(req, dict): + return req.get('certifiers') or req.get('Certifiers') or [] + return [] + + def _extract_types_map_from_req(self, req: Any) -> Dict[bytes, list]: + result: Dict[bytes, list] = {} + if self._is_rcs_like(req): + raw = getattr(getattr(req, 'certificate_types', None), 'mapping', {}) or {} + elif isinstance(req, dict): + raw = req.get('certificate_types') or req.get('certificateTypes') or req.get('types') or {} + else: + raw = {} + for k, v in raw.items(): + key_b = bytes(k) if isinstance(k, (bytes, bytearray)) else self._decode_type_bytes(k) + if key_b is not None: + result[key_b] = list(v or []) + return result + + def _normalize_requested_certificate_constraints(self, req: Any): + try: + certifiers = self._extract_certifiers_from_req(req) + types_map = self._extract_types_map_from_req(req) + return certifiers, types_map + except Exception: + return [], {} + + def _decode_type_bytes(self, val: Any) -> Optional[bytes]: + if isinstance(val, (bytes, bytearray)): + return bytes(val) + if isinstance(val, str): + try: + import base64 as _b64 + return _b64.b64decode(val) + except Exception: + try: + return bytes.fromhex(val) + except Exception: + return None + return None + + # Granular validators for a single certificate + def _get_base_cert(self, cert: Any) -> Any: + return getattr(cert, 'certificate', cert) + + def _has_valid_signature(self, ctx: Any, cert: Any) -> bool: + try: + if hasattr(cert, 'verify') and not cert.verify(ctx): + self.logger.warning(f"Certificate signature invalid: {cert}") + return False + except Exception as e: + self.logger.warning(f"Certificate signature verification error: {e}") + return False + return True + + def _subject_matches_expected(self, expected_subject: Any, base_cert: Any) -> bool: + if expected_subject is None: + return True + try: + subj_hex = self._pubkey_to_hex(getattr(base_cert, 'subject', None)) + exp_hex = self._pubkey_to_hex(expected_subject) + if subj_hex is None or exp_hex is None or subj_hex != exp_hex: + self.logger.warning("Certificate subject does not match the expected identity key") + return False + return True + except Exception as e: + self.logger.warning(f"Subject comparison failed: {e}") + return False + + def _is_certifier_allowed(self, allowed_certifier_hexes: Set[str], base_cert: Any) -> bool: + if not allowed_certifier_hexes: + return True + try: + cert_hex = self._pubkey_to_hex(getattr(base_cert, 'certifier', None)) + if cert_hex is None or cert_hex.lower() not in allowed_certifier_hexes: + self.logger.warning("Certificate has unrequested certifier") + return False + return True + except Exception as e: + self.logger.warning(f"Certifier check failed: {e}") + return False + + def _type_and_fields_valid(self, requested_types: Dict[bytes, list], base_cert: Any) -> bool: + if not requested_types: + return True + try: + cert_type_bytes = self._decode_type_bytes(getattr(base_cert, 'type', None)) + if not cert_type_bytes: + self.logger.warning("Invalid certificate type encoding") + return False + if cert_type_bytes not in requested_types: + self.logger.warning("Certificate type was not requested") + return False + required_fields = requested_types.get(cert_type_bytes, []) + cert_fields = getattr(base_cert, 'fields', {}) or {} + for field in required_fields: + if field not in cert_fields: + self.logger.warning(f"Certificate missing required field: {field}") + return False + return True + except Exception as e: + self.logger.warning(f"Type/fields validation failed: {e}") + return False + + def _validate_single_certificate( + self, + ctx: Any, + cert: Any, + expected_subject: Any, + allowed_certifier_hexes: Set[str], + requested_types: Dict[bytes, list], + ) -> bool: + base_cert = self._get_base_cert(cert) + if not self._has_valid_signature(ctx, cert): + return False + if not self._subject_matches_expected(expected_subject, base_cert): + return False + if not self._is_certifier_allowed(allowed_certifier_hexes, base_cert): + return False + if not self._type_and_fields_valid(requested_types, base_cert): + return False + return True + def _validate_certificates(self, ctx: Any, certs: list, requested_certs: Any = None, expected_subject: Any = None) -> bool: """ Validate VerifiableCertificates against a RequestedCertificateSet or dict. @@ -420,121 +599,16 @@ def _validate_certificates(self, ctx: Any, certs: list, requested_certs: Any = N - Ensures type is requested and required fields are present (if provided) - Ensures subject matches expected_subject (if provided) """ - from .requested_certificate_set import RequestedCertificateSet valid = True - - def _normalize_requested(req: Any): - certifiers = [] - type_map = {} - try: - if isinstance(req, RequestedCertificateSet): - certifiers = list(getattr(req, 'certifiers', []) or []) - mapping = getattr(getattr(req, 'certificate_types', None), 'mapping', {}) or {} - type_map = dict(mapping) - elif isinstance(req, dict): - certifiers = req.get('certifiers') or req.get('Certifiers') or [] - types_dict = req.get('certificate_types') or req.get('certificateTypes') or req.get('types') or {} - for k, v in types_dict.items(): - if isinstance(k, (bytes, bytearray)): - key_b = bytes(k) - else: - try: - key_b = base64.b64decode(k) - except Exception: - continue - type_map[key_b] = list(v or []) - except Exception: - pass - return certifiers, type_map - - allowed_certifiers, requested_types = _normalize_requested(requested_certs) - # Normalize allowed certifiers to hex strings for comparison + allowed_certifiers, requested_types = self._normalize_requested_certificate_constraints(requested_certs) allowed_certifier_hexes: Set[str] = set() for c in allowed_certifiers or []: - try: - if hasattr(c, 'hex'): - allowed_certifier_hexes.add(c.hex()) - elif isinstance(c, (bytes, bytearray)): - allowed_certifier_hexes.add(bytes(c).hex()) - elif isinstance(c, str): - # accept hex strings - int(c, 16) - allowed_certifier_hexes.add(c.lower()) - except Exception: - continue + hx = self._pubkey_to_hex(c) + if isinstance(hx, str): + allowed_certifier_hexes.add(hx.lower()) for cert in certs: - try: - base_cert = getattr(cert, 'certificate', cert) - # Signature verification - if hasattr(cert, 'verify') and not cert.verify(ctx): - self.logger.warning(f"Certificate signature invalid: {cert}") - valid = False - continue - # Subject verification - if expected_subject is not None: - subj = getattr(base_cert, 'subject', None) - try: - subj_hex = subj.hex() if hasattr(subj, 'hex') else None - exp_hex = expected_subject.hex() if hasattr(expected_subject, 'hex') else None - if subj_hex is None or exp_hex is None or subj_hex != exp_hex: - self.logger.warning("Certificate subject does not match the expected identity key") - valid = False - continue - except Exception: - self.logger.warning("Failed to compare certificate subject with expected identity key") - valid = False - continue - # Certifier verification - if allowed_certifier_hexes: - certifier_val = getattr(base_cert, 'certifier', None) - try: - if hasattr(certifier_val, 'hex'): - cert_hex = certifier_val.hex() - elif isinstance(certifier_val, (bytes, bytearray)): - cert_hex = bytes(certifier_val).hex() - else: - cert_hex = str(certifier_val) - except Exception: - cert_hex = None - if cert_hex is None or cert_hex.lower() not in allowed_certifier_hexes: - self.logger.warning("Certificate has unrequested certifier") - valid = False - continue - # Type / fields verification - if requested_types: - cert_type_val = getattr(base_cert, 'type', None) - # Accept base64/hex/bytes - cert_type_bytes = None - if isinstance(cert_type_val, (bytes, bytearray)): - cert_type_bytes = bytes(cert_type_val) - elif isinstance(cert_type_val, str): - try: - b = base64.b64decode(cert_type_val) - cert_type_bytes = b - except Exception: - try: - b = bytes.fromhex(cert_type_val) - cert_type_bytes = b - except Exception: - cert_type_bytes = None - if not cert_type_bytes: - self.logger.warning("Invalid certificate type encoding") - valid = False - continue - if cert_type_bytes not in requested_types: - self.logger.warning("Certificate type was not requested") - valid = False - continue - required_fields = requested_types.get(cert_type_bytes, []) - cert_fields = getattr(base_cert, 'fields', {}) or {} - for field in required_fields: - if field not in cert_fields: - self.logger.warning(f"Certificate missing required field: {field}") - valid = False - break - except Exception as e: - self.logger.warning(f"Certificate validation error: {e}") + if not self._validate_single_certificate(ctx, cert, expected_subject, allowed_certifier_hexes, requested_types): valid = False return valid @@ -544,16 +618,29 @@ def handle_initial_response(self, ctx: Any, message: Any, sender_public_key: Any """ if self._debug: print("[Peer DEBUG] handle_initial_response: begin") + session = self._retrieve_initial_response_session(sender_public_key, message) + if session is None: + return Exception(self.SESSION_NOT_FOUND) + + err = self._verify_and_update_session_from_initial_response(ctx, message, session) + if err is not None: + return err + + self._process_initial_response_certificates(ctx, message, sender_public_key) + self._notify_initial_response_waiters(session, message) + self._handle_requested_certificates_from_peer_message(ctx, message, sender_public_key, source_label="initialResponse") + return None + + def _retrieve_initial_response_session(self, sender_public_key: Any, message: Any) -> Optional[Any]: session = self.session_manager.get_session(sender_public_key.hex()) if sender_public_key else None if session is None: - # Fallback: try to match by our original initial nonce carried in your_nonce your_nonce = getattr(message, 'your_nonce', None) if your_nonce: session = self.session_manager.get_session(your_nonce) - if session is None: - return Exception("Session not found") + return session + + def _verify_and_update_session_from_initial_response(self, ctx: Any, message: Any, session: Any) -> Optional[Exception]: try: - # Reconstruct signature data in the same order as signer (request.initial_nonce + response.session_nonce) client_initial_bytes = base64.b64decode(getattr(message, 'your_nonce', '')) server_session_bytes = base64.b64decode(getattr(message, 'initial_nonce', '')) except Exception as e: @@ -564,7 +651,7 @@ def handle_initial_response(self, ctx: Any, message: Any, sender_public_key: Any 'encryption_args': { 'protocol_id': { 'securityLevel': 2, - 'protocol': "auth message signature" + 'protocol': self.AUTH_MESSAGE_SIGNATURE }, 'key_id': f"{getattr(message, 'your_nonce', '')} {getattr(message, 'initial_nonce', '')}", 'counterparty': { @@ -586,42 +673,103 @@ def handle_initial_response(self, ctx: Any, message: Any, sender_public_key: Any session.last_update = int(time.time() * 1000) self.session_manager.update_session(session) self.last_interacted_with_peer = getattr(message, 'identity_key', None) - # Certificate verification logic + return None + + def _process_initial_response_certificates(self, ctx: Any, message: Any, sender_public_key: Any) -> None: certs = getattr(message, 'certificates', []) - if certs: - # Strict verification: match against requested set and sender's identity_key - valid = self._validate_certificates( - ctx, - certs, - getattr(self, 'certificates_to_request', None), - expected_subject=getattr(message, 'identity_key', None), - ) - if not valid: - self.logger.warning("Invalid certificates in initial response") - for callback in self.on_certificate_received_callbacks.values(): - try: - callback(sender_public_key, certs) - except Exception as e: - self.logger.warning(f"Certificate received callback error: {e}") - # Notify any waiting initial-response callbacks registered during initiate_handshake + if not certs: + return + valid = self._validate_certificates( + ctx, + certs, + getattr(self, 'certificates_to_request', None), + expected_subject=getattr(message, 'identity_key', None), + ) + if not valid: + self.logger.warning("Invalid certificates in initial response") + for callback in self.on_certificate_received_callbacks.values(): + try: + callback(sender_public_key, certs) + except Exception as e: + self.logger.warning(f"Certificate received callback error: {e}") + + def _notify_initial_response_waiters(self, session: Any, message: Any) -> None: try: to_delete = None for cb_id, info in self.on_initial_response_received_callbacks.items(): if info.get('session_nonce') == session.session_nonce: - # Prefer to pass the peer's nonce to the callback peer_nonce = session.peer_nonce or getattr(message, 'initial_nonce', None) + to_delete = cb_id try: info.get('callback')(peer_nonce) - finally: - to_delete = cb_id - break + except Exception as e: + self.logger.warning(f"Initial response callback execution error: {e}") + break if to_delete is not None: del self.on_initial_response_received_callbacks[to_delete] except Exception as e: self.logger.warning(f"Initial response callback error: {e}") - # TODO: Handle requested certificates from peer if present - return None + def _handle_requested_certificates_from_peer_message(self, ctx: Any, message: Any, sender_public_key: Any, source_label: str = "") -> None: + try: + req_from_peer = getattr(message, 'requested_certificates', None) + if not self._has_requested_certificates(req_from_peer): + return + + if self._try_callbacks_for_requested_certs(ctx, sender_public_key, req_from_peer, source_label): + return + + self._auto_reply_with_requested_certs(ctx, message, sender_public_key, req_from_peer) + except Exception as e: + self.logger.warning(f"Requested certificates processing error: {e}") + + def _has_requested_certificates(self, req_from_peer: Any) -> bool: + if req_from_peer is None: + return False + if hasattr(req_from_peer, 'certifiers') and getattr(req_from_peer, 'certifiers'): + return True + if isinstance(req_from_peer, dict): + return bool( + req_from_peer.get('certifiers') + or req_from_peer.get('certificate_types') + or req_from_peer.get('certificateTypes') + or req_from_peer.get('types') + ) + return False + + def _try_callbacks_for_requested_certs(self, ctx: Any, sender_public_key: Any, req_from_peer: Any, source_label: str) -> bool: + if not self.on_certificate_request_received_callbacks: + return False + for cb in tuple(self.on_certificate_request_received_callbacks.values()): + try: + result = cb(sender_public_key, req_from_peer) + if result: + err = self.send_certificate_response(ctx, sender_public_key, result) + if err is None: + return True + except Exception as e: + self.logger.warning(f"Certificate request callback error ({source_label} handling): {e}") + return False + + def _auto_reply_with_requested_certs(self, ctx: Any, message: Any, sender_public_key: Any, req_from_peer: Any) -> None: + try: + canonical_req = self._canonicalize_requested_certificates(req_from_peer) + req_for_utils = { + 'certifiers': canonical_req.get('certifiers', []), + 'types': canonical_req.get('certificateTypes', {}) + } + from .utils import get_verifiable_certificates + verifiable = get_verifiable_certificates( + self.wallet, + req_for_utils, + getattr(message, 'identity_key', None) + ) + if verifiable is not None: + _err = self.send_certificate_response(ctx, sender_public_key, verifiable) + if _err is not None: + self.logger.warning(f"Failed to send auto certificate response: {_err}") + except Exception as e: + self.logger.warning(f"Auto certificate response error: {e}") def handle_certificate_request(self, ctx: Any, message: Any, sender_public_key: Any) -> Optional[Exception]: """ @@ -631,17 +779,38 @@ def handle_certificate_request(self, ctx: Any, message: Any, sender_public_key: print("[Peer DEBUG] handle_certificate_request: begin") session = self.session_manager.get_session(sender_public_key.hex()) if sender_public_key else None if session is None: - return Exception("Session not found") - # --- Signature verification logic implementation --- + return Exception(self.SESSION_NOT_FOUND) + requested = getattr(message, 'requested_certificates', {}) canonical_req = self._canonicalize_requested_certificates(requested) + err = self._verify_certificate_request_signature(ctx, message, session, sender_public_key, canonical_req) + if err is not None: + return err + + self._touch_session(session) + + certs_to_send = self._invoke_cert_request_callbacks(sender_public_key, requested) + if certs_to_send is None: + subject_hex = self._get_identity_subject_hex(ctx) + if subject_hex is None: + return Exception("failed to get identity key for certificate response") + certs_to_send = self._auto_acquire_certificates_for_request(ctx, canonical_req, subject_hex) + + if self._debug: + print(f"[Peer DEBUG] handle_certificate_request: sending response, certs={len(certs_to_send or [])}") + err = self.send_certificate_response(ctx, sender_public_key, certs_to_send or []) + if err is not None: + return Exception(f"failed to send certificate response: {err}") + return None + + def _verify_certificate_request_signature(self, ctx: Any, message: Any, session: Any, sender_public_key: Any, canonical_req: dict) -> Optional[Exception]: cert_request_data = self._serialize_for_signature(canonical_req) signature = getattr(message, 'signature', None) verify_result = self.wallet.verify_signature(ctx, { 'encryption_args': { 'protocol_id': { 'securityLevel': 2, - 'protocol': "auth message signature" + 'protocol': self.AUTH_MESSAGE_SIGNATURE }, 'key_id': f"{getattr(message, 'nonce', '')} {session.session_nonce}", 'counterparty': { @@ -656,68 +825,58 @@ def handle_certificate_request(self, ctx: Any, message: Any, sender_public_key: print(f"[Peer DEBUG] handle_certificate_request: verify_result={getattr(verify_result, 'valid', None)}") if verify_result is None or not getattr(verify_result, 'valid', False): return Exception("certificate request - invalid signature") + return None + + def _touch_session(self, session: Any) -> None: import time session.last_update = int(time.time() * 1000) self.session_manager.update_session(session) - # --- Response side implementation: callback -> acquire -> sign -> send --- - certs_to_send = None - # 1) Prioritize callbacks if any - if self.on_certificate_request_received_callbacks: - if self._debug: - print("[Peer DEBUG] handle_certificate_request: invoking request callbacks") - for cb in list(self.on_certificate_request_received_callbacks.values()): - try: - result = cb(sender_public_key, requested) - if result: - certs_to_send = result - break - except Exception as e: - self.logger.warning(f"Certificate request callback error: {e}") - # 2) Fallback: acquire from wallet/store - if certs_to_send is None: - if self._debug: - print("[Peer DEBUG] handle_certificate_request: fallback to wallet.acquire_certificate") - certs: list = [] + + def _invoke_cert_request_callbacks(self, sender_public_key: Any, requested: Any): + if not self.on_certificate_request_received_callbacks: + return None + if self._debug: + print("[Peer DEBUG] handle_certificate_request: invoking request callbacks") + for cb in tuple(self.on_certificate_request_received_callbacks.values()): try: - # Our identity key - identity_key_result = self.wallet.get_public_key(ctx, {'identityKey': True}, "auth-peer") - subject_hex = getattr(getattr(identity_key_result, 'public_key', None), 'hex', lambda: None)() - if subject_hex is None: - raise RuntimeError("failed to get identity key for certificate response") - # Acquire certificates (RequestedCertificateSet compatible) - try: - from .requested_certificate_set import RequestedCertificateSet - except Exception: - RequestedCertificateSet = None # type: ignore - # Read from normalized canonical_req - certifiers_list = canonical_req.get('certifiers', []) - types_dict = canonical_req.get('certificateTypes', {}) - for cert_type_b64, fields in types_dict.items(): - args = { - 'cert_type': cert_type_b64, - 'fields': list(fields or []), - 'subject': subject_hex, - 'certifiers': list(certifiers_list or []), - } - try: - cert_result = self.wallet.acquire_certificate(ctx, args, "auth-peer") - except Exception: - cert_result = None - if isinstance(cert_result, list): - certs.extend(cert_result) - elif cert_result is not None: - certs.append(cert_result) + result = cb(sender_public_key, requested) + if result: + return result except Exception as e: - self.logger.warning(f"Failed to acquire certificates for response: {e}") - certs_to_send = certs - # 3) Send response - if self._debug: - print(f"[Peer DEBUG] handle_certificate_request: sending response, certs={len(certs_to_send or [])}") - err = self.send_certificate_response(ctx, sender_public_key, certs_to_send or []) - if err is not None: - return Exception(f"failed to send certificate response: {err}") + self.logger.warning(f"Certificate request callback error: {e}") return None + def _get_identity_subject_hex(self, ctx: Any) -> Optional[str]: + try: + identity_key_result = self.wallet.get_public_key(ctx, {'identityKey': True}, "auth-peer") + return getattr(getattr(identity_key_result, 'public_key', None), 'hex', lambda: None)() + except Exception: + return None + + def _auto_acquire_certificates_for_request(self, ctx: Any, canonical_req: dict, subject_hex: str) -> list: + certs: list = [] + try: + certifiers_list = canonical_req.get('certifiers', []) + types_dict = canonical_req.get('certificateTypes', {}) + for cert_type_b64, fields in types_dict.items(): + args = { + 'cert_type': cert_type_b64, + 'fields': list(fields or []), + 'subject': subject_hex, + 'certifiers': list(certifiers_list or []), + } + try: + cert_result = self.wallet.acquire_certificate(ctx, args, "auth-peer") + except Exception: + cert_result = None + if isinstance(cert_result, list): + certs.extend(cert_result) + elif cert_result is not None: + certs.append(cert_result) + except Exception as e: + self.logger.warning(f"Failed to acquire certificates for response: {e}") + return certs + def handle_certificate_response(self, ctx: Any, message: Any, sender_public_key: Any) -> Optional[Exception]: """ Processes a certificate response message. @@ -726,16 +885,29 @@ def handle_certificate_response(self, ctx: Any, message: Any, sender_public_key: print("[Peer DEBUG] handle_certificate_response: begin") session = self.session_manager.get_session(sender_public_key.hex()) if sender_public_key else None if session is None: - return Exception("Session not found") + return Exception(self.SESSION_NOT_FOUND) + certs = getattr(message, 'certificates', []) canonical_certs = self._canonicalize_certificates_payload(certs) cert_data = self._serialize_for_signature(canonical_certs) + + err = self._verify_certificate_response_signature(ctx, message, session, sender_public_key, cert_data) + if err is not None: + return err + + self._touch_session(session) + + self._process_certificate_response_certificates(ctx, message, sender_public_key) + self._handle_requested_certificates_from_peer_message(ctx, message, sender_public_key, source_label="certificateResponse") + return None + + def _verify_certificate_response_signature(self, ctx: Any, message: Any, session: Any, sender_public_key: Any, cert_data: bytes) -> Optional[Exception]: signature = getattr(message, 'signature', None) verify_result = self.wallet.verify_signature(ctx, { 'encryption_args': { 'protocol_id': { 'securityLevel': 2, - 'protocol': "auth message signature" + 'protocol': self.AUTH_MESSAGE_SIGNATURE }, 'key_id': f"{getattr(message, 'nonce', '')} {session.session_nonce}", 'counterparty': { @@ -750,34 +922,53 @@ def handle_certificate_response(self, ctx: Any, message: Any, sender_public_key: print(f"[Peer DEBUG] handle_certificate_response: verify_result={getattr(verify_result, 'valid', None)}") if verify_result is None or not getattr(verify_result, 'valid', False): return Exception("certificate response - invalid signature") - import time - session.last_update = int(time.time() * 1000) - self.session_manager.update_session(session) - # Certificate verification logic - certs = getattr(message, 'certificates', []) - if certs: - valid = self._validate_certificates( - ctx, - certs, - getattr(self, 'certificates_to_request', None), - expected_subject=getattr(message, 'identity_key', None), - ) - if not valid: - self.logger.warning("Invalid certificates in certificate response") - for callback in self.on_certificate_received_callbacks.values(): - try: - callback(sender_public_key, certs) - except Exception as e: - self.logger.warning(f"Certificate callback error: {e}") return None + def _process_certificate_response_certificates(self, ctx: Any, message: Any, sender_public_key: Any) -> None: + certs = getattr(message, 'certificates', []) + if not certs: + return + valid = self._validate_certificates( + ctx, + certs, + getattr(self, 'certificates_to_request', None), + expected_subject=getattr(message, 'identity_key', None), + ) + if not valid: + self.logger.warning("Invalid certificates in certificate response") + for callback in self.on_certificate_received_callbacks.values(): + try: + callback(sender_public_key, certs) + except Exception as e: + self.logger.warning(f"Certificate callback error: {e}") + def handle_general_message(self, ctx: Any, message: Any, sender_public_key: Any) -> Optional[Exception]: """ Processes a general message. """ if self._debug: print("[Peer DEBUG] handle_general_message: begin") - # Optional: validate nonce for replay protection (non-fatal) + self._optionally_verify_nonce(ctx, message, sender_public_key) + if self._is_loopback_echo(ctx, sender_public_key): + return None + + session = self.session_manager.get_session(sender_public_key.hex()) if sender_public_key else None + if session is None: + return Exception(self.SESSION_NOT_FOUND) + + payload = getattr(message, 'payload', None) + data_to_verify = self._serialize_for_signature(payload) + err = self._verify_general_message_signature(ctx, message, session, sender_public_key, data_to_verify) + if err is not None: + return err + + self._touch_session(session) + if self.auto_persist_last_session: + self.last_interacted_with_peer = sender_public_key + self._dispatch_general_message_callbacks(sender_public_key, payload) + return None + + def _optionally_verify_nonce(self, ctx: Any, message: Any, sender_public_key: Any) -> None: try: from .utils import verify_nonce nonce = getattr(message, 'nonce', None) @@ -785,27 +976,24 @@ def handle_general_message(self, ctx: Any, message: Any, sender_public_key: Any) self.logger.warning("general message - nonce verification failed") except Exception: pass - # If this is a loopback of our own outbound message (test transport echoes), ignore gracefully + + def _is_loopback_echo(self, ctx: Any, sender_public_key: Any) -> bool: try: identity_key_result = self.wallet.get_public_key(ctx, {'identityKey': True}, "auth-peer") if identity_key_result is not None and hasattr(identity_key_result, 'public_key') and sender_public_key is not None: if getattr(identity_key_result.public_key, 'hex', None) and getattr(sender_public_key, 'hex', None): - if identity_key_result.public_key.hex() == sender_public_key.hex(): - return None + return identity_key_result.public_key.hex() == sender_public_key.hex() except Exception: pass - session = self.session_manager.get_session(sender_public_key.hex()) if sender_public_key else None - if session is None: - return Exception("Session not found") - # --- Signature verification logic implementation --- + return False + + def _verify_general_message_signature(self, ctx: Any, message: Any, session: Any, sender_public_key: Any, data_to_verify: bytes) -> Optional[Exception]: signature = getattr(message, 'signature', None) - payload = getattr(message, 'payload', None) - data_to_verify = self._serialize_for_signature(payload) verify_result = self.wallet.verify_signature(ctx, { 'encryption_args': { 'protocol_id': { 'securityLevel': 2, - 'protocol': "auth message signature" + 'protocol': self.AUTH_MESSAGE_SIGNATURE }, 'key_id': f"{getattr(message, 'nonce', '')} {session.session_nonce}", 'counterparty': { @@ -816,19 +1004,16 @@ def handle_general_message(self, ctx: Any, message: Any, sender_public_key: Any) 'data': data_to_verify, 'signature': signature }, "auth-peer") - if verify_result is None or not getattr(verify_result, 'valid', False): + if not getattr(verify_result, 'valid', False): return Exception("general message - invalid signature") - import time - session.last_update = int(time.time() * 1000) - self.session_manager.update_session(session) - if self.auto_persist_last_session: - self.last_interacted_with_peer = sender_public_key + return None + + def _dispatch_general_message_callbacks(self, sender_public_key: Any, payload: Any) -> None: for callback in self.on_general_message_received_callbacks.values(): try: callback(sender_public_key, payload) except Exception as e: self.logger.warning(f"General message callback error: {e}") - return None def expire_sessions(self, max_age_sec: int = 3600): """ @@ -858,8 +1043,29 @@ def expire_sessions(self, max_age_sec: int = 3600): print(f"[Peer DEBUG] expire_sessions: removed={before - after}, remaining={after}") def stop(self): - # TODO: Clean up any resources if needed - pass + """ + Stop the peer. Aligns with TS/Go behavior (no strict teardown required), + but performs best-effort cleanup: + - Deregister transport handler by installing a no-op + - Clear registered callbacks to avoid leaks + """ + if self._debug: + print("[Peer DEBUG] stop: begin") + # Best-effort: replace on_data with a no-op to stop receiving messages + try: + _ = self.transport.on_data(lambda _ctx, _msg: None) + except Exception: + pass + # Clear callback registries + try: + self.on_general_message_received_callbacks.clear() + self.on_certificate_received_callbacks.clear() + self.on_certificate_request_received_callbacks.clear() + self.on_initial_response_received_callbacks.clear() + except Exception: + pass + if self._debug: + print("[Peer DEBUG] stop: done") def listen_for_general_messages(self, callback: Callable) -> int: """ @@ -930,9 +1136,13 @@ def initiate_handshake(self, ctx: Any, peer_identity_key: Any, max_wait_time_ms: """ Starts the mutual authentication handshake with a peer. """ - # TODO: Replace with actual nonce creation logic - import os, base64, time - session_nonce = base64.b64encode(os.urandom(32)).decode('ascii') + import time + try: + from .utils import create_nonce + session_nonce = create_nonce(self.wallet, { 'type': 1 }, ctx) + except Exception: + import os, base64 + session_nonce = base64.b64encode(os.urandom(32)).decode('ascii') # Add a preliminary session entry (not yet authenticated) from .peer_session import PeerSession session = PeerSession( @@ -1009,12 +1219,12 @@ def to_peer(self, ctx: Any, message: bytes, identity_key: Optional[Any] = None, identity_key = self.last_interacted_with_peer peer_session = self.get_authenticated_session(ctx, identity_key, max_wait_time) if peer_session is None: - return Exception("failed to get authenticated session") + return Exception(self.FAILED_TO_GET_AUTHENTICATED_SESSION) import os, base64, time request_nonce = base64.b64encode(os.urandom(32)).decode('ascii') identity_key_result = self.wallet.get_public_key(ctx, {'identityKey': True}, "auth-peer") if identity_key_result is None or not hasattr(identity_key_result, 'public_key'): - return Exception("failed to get identity key") + return Exception(self.FAIL_TO_GET_IDENTIFY_KEY) from .auth_message import AuthMessage general_message = AuthMessage( version="0.1", @@ -1030,7 +1240,7 @@ def to_peer(self, ctx: Any, message: bytes, identity_key: Optional[Any] = None, 'encryption_args': { 'protocol_id': { 'securityLevel': 2, - 'protocol': "auth message signature" + 'protocol': self.AUTH_MESSAGE_SIGNATURE }, 'key_id': f"{request_nonce} {peer_session.peer_nonce}", 'counterparty': { @@ -1060,14 +1270,14 @@ def request_certificates(self, ctx: Any, identity_key: Any, certificate_requirem # Get or create an authenticated session peer_session = self.get_authenticated_session(ctx, identity_key, max_wait_time) if peer_session is None: - return Exception("failed to get authenticated session") + return Exception(self.FAILED_TO_GET_AUTHENTICATED_SESSION) # Create a nonce for this request import os, base64, time request_nonce = base64.b64encode(os.urandom(32)).decode('ascii') # Get identity key identity_key_result = self.wallet.get_public_key(ctx, {'identityKey': True}, "auth-peer") if identity_key_result is None or not hasattr(identity_key_result, 'public_key'): - return Exception("failed to get identity key") + return Exception(self.FAIL_TO_GET_IDENTIFY_KEY) # Create certificate request message from .auth_message import AuthMessage cert_request = AuthMessage( @@ -1084,7 +1294,7 @@ def request_certificates(self, ctx: Any, identity_key: Any, certificate_requirem 'encryption_args': { 'protocol_id': { 'securityLevel': 2, - 'protocol': "auth message signature" + 'protocol': self.AUTH_MESSAGE_SIGNATURE }, 'key_id': f"{request_nonce} {peer_session.peer_nonce}", 'counterparty': { @@ -1118,14 +1328,14 @@ def send_certificate_response(self, ctx: Any, identity_key: Any, certificates: A print(f"[Peer DEBUG] send_certificate_response: begin, certs_in={(len(certificates) if isinstance(certificates, list) else 'n/a')}") peer_session = self.get_authenticated_session(ctx, identity_key, 0) if peer_session is None: - return Exception("failed to get authenticated session") + return Exception(self.FAILED_TO_GET_AUTHENTICATED_SESSION) # Create a nonce for this response import os, base64, time response_nonce = base64.b64encode(os.urandom(32)).decode('ascii') # Get identity key identity_key_result = self.wallet.get_public_key(ctx, {'identityKey': True}, "auth-peer") if identity_key_result is None or not hasattr(identity_key_result, 'public_key'): - return Exception("failed to get identity key") + return Exception(self.FAIL_TO_GET_IDENTIFY_KEY) # Create certificate response message from .auth_message import AuthMessage cert_response = AuthMessage( @@ -1144,7 +1354,7 @@ def send_certificate_response(self, ctx: Any, identity_key: Any, certificates: A 'encryption_args': { 'protocol_id': { 'securityLevel': 2, - 'protocol': "auth message signature" + 'protocol': self.AUTH_MESSAGE_SIGNATURE }, 'key_id': f"{response_nonce} {peer_session.peer_nonce}", 'counterparty': { @@ -1252,7 +1462,6 @@ class PeerAuthError(Exception): class CertificateError(Exception): """Raised for certificate validation or issuance errors.""" - pass # --- 7. Serialization/deserialization helpers --- def serialize_data(self, data: Any) -> bytes: @@ -1304,10 +1513,24 @@ def expire_sessions(self, max_age_sec: int = 3600): # --- 9. Transport security stub (for extension) --- def secure_send(self, ctx: Any, message: Any) -> Optional[Exception]: """ - Send a message with additional security (encryption, MAC, etc.). - This is a stub for future extension. + Send a message with additional security. + + Parity with TS/Go SDKs: + - The current TS and Go implementations do not provide an additional + secure-send layer beyond signing. For protocol parity, we delegate + directly to the underlying transport. + + Forward-compatibility: + - If a transport exposes a `secure_send` method, prefer it. + - Otherwise, fall back to `send`. """ - # TODO: Implement encryption/MAC as needed + try: + secure = getattr(self.transport, 'secure_send', None) + if callable(secure): + return secure(ctx, message) + except Exception: + # Fall back to normal send on any error + pass return self.transport.send(ctx, message) # --- 10. Integration/E2E test utility --- diff --git a/bsv/auth/utils.py b/bsv/auth/utils.py index 07225f5..78a9b03 100644 --- a/bsv/auth/utils.py +++ b/bsv/auth/utils.py @@ -18,10 +18,10 @@ def verify_nonce(nonce: str, wallet: Any, counterparty: Any = None, ctx: Any = N # Prepare encryption_args for wallet.verify_hmac encryption_args = { 'protocol_id': { - 'securityLevel': 1, # Go版: SecurityLevelEveryApp = 1 + 'securityLevel': 1, # Go version: SecurityLevelEveryApp = 1 'protocol': 'server hmac' }, - 'key_id': data.decode('latin1'), # Go版: string(randomBytes) + 'key_id': data.decode('latin1'), # Go version: string(randomBytes) 'counterparty': counterparty } args = { @@ -48,10 +48,10 @@ def create_nonce(wallet: Any, counterparty: Any = None, ctx: Any = None) -> str: # Create an sha256 HMAC encryption_args = { 'protocol_id': { - 'securityLevel': 1, # Go版: SecurityLevelEveryApp = 1 + 'securityLevel': 1, # Go version: SecurityLevelEveryApp = 1 'protocol': 'server hmac' }, - 'key_id': first_half.decode('latin1'), # Go版: string(randomBytes) + 'key_id': first_half.decode('latin1'), # Go version: string(randomBytes) 'counterparty': counterparty } args = { @@ -102,100 +102,118 @@ def get_verifiable_certificates(wallet, requested_certificates, verifier_identit def validate_certificates(verifier_wallet, message, certificates_requested=None): """ - Validates and processes certificates received from a peer. + Validate and process certificates received from a peer. - Ensures each certificate's subject equals message.identityKey - Verifies signature - If certificates_requested is provided, enforces certifier/type/required fields - Attempts to decrypt fields using the verifier wallet Raises Exception on validation failure. """ - from bsv.auth.verifiable_certificate import VerifiableCertificate - - certificates = getattr(message, 'certificates', None) or (message.get('certificates', None) if isinstance(message, dict) else None) - identity_key = getattr(message, 'identityKey', None) or (message.get('identityKey', None) if isinstance(message, dict) else None) + certificates = _extract_message_certificates(message) + identity_key = _extract_message_identity_key(message) if not certificates: raise ValueError('No certificates were provided in the AuthMessage.') if identity_key is None: raise ValueError('identityKey must be provided in the AuthMessage.') - # Normalize certificates_requested into (allowed_certifiers, requested_types_map) - def _normalize_requested(req): - allowed_certifiers = [] - requested_types = {} - if req is None: - return allowed_certifiers, requested_types - try: - # RequestedCertificateSet - from bsv.auth.requested_certificate_set import RequestedCertificateSet - if isinstance(req, RequestedCertificateSet): - allowed_certifiers = list(getattr(req, 'certifiers', []) or []) - # For utils we expect plain string type keys; convert bytes keys to base64 strings - mapping = getattr(getattr(req, 'certificate_types', None), 'mapping', {}) or {} - requested_types = {base64.b64encode(k).decode('ascii'): list(v or []) for k, v in mapping.items()} - return allowed_certifiers, requested_types - except Exception: - pass - # dict-like - if isinstance(req, dict): - allowed_certifiers = req.get('certifiers') or req.get('Certifiers') or [] - types_dict = req.get('certificate_types') or req.get('certificateTypes') or req.get('types') or {} - # In utils tests, type keys are simple strings. Keep as-is. - for k, v in types_dict.items(): - requested_types[str(k)] = list(v or []) - return allowed_certifiers, requested_types - - allowed_certifiers, requested_types = _normalize_requested(certificates_requested) + allowed_certifiers, requested_types = _normalize_requested_for_utils(certificates_requested) for incoming in certificates: - # Extract fields as-is (tests expect plain strings, not decoded keys) - cert_type = incoming.get('type') - serial_number = incoming.get('serialNumber') or incoming.get('serial_number') - subject = incoming.get('subject') - certifier = incoming.get('certifier') - fields = incoming.get('fields') or {} - signature = incoming.get('signature') - keyring = incoming.get('keyring') or {} - - if subject != identity_key: - raise ValueError(f'The subject of one of your certificates ("{subject}") is not the same as the request sender ("{identity_key}").') - - # Instantiate VerifiableCertificate with backwards-compatible signature used in tests - try: - vc = VerifiableCertificate(cert_type, serial_number, subject, certifier, incoming.get('revocationOutpoint'), fields, keyring, signature) - except Exception: - # Fallback: if real class is present, try wrapping via real constructor - try: - from bsv.auth.certificate import Certificate as _Cert, Outpoint as _Out - from bsv.keys import PublicKey as _PK - subj_pk = _PK(subject) - cert_pk = _PK(certifier) if certifier else None - rev = incoming.get('revocationOutpoint') - rev_out = None - if isinstance(rev, dict): - txid = rev.get('txid') or rev.get('txID') or rev.get('txId') - index = rev.get('index') or rev.get('vout') - if txid is not None and index is not None: - rev_out = _Out(txid, int(index)) - base = _Cert(cert_type, serial_number, subj_pk, cert_pk, rev_out, fields, signature) - vc = VerifiableCertificate(base, keyring) - except Exception as e: - raise e - - # Signature verification + cert_type, serial_number, subject, certifier, fields, signature, keyring = _extract_incoming_fields(incoming) + + _ensure_subject_matches(subject, identity_key) + + vc = _build_verifiable_certificate(incoming, cert_type, serial_number, subject, certifier, fields, signature, keyring) + if not vc.verify(): raise ValueError(f'The signature for the certificate with serial number {serial_number} is invalid!') - # Requested constraints - if allowed_certifiers or requested_types: - if allowed_certifiers and certifier not in allowed_certifiers: - raise ValueError(f'Certificate with serial number {serial_number} has an unrequested certifier') - if requested_types and cert_type not in requested_types: - raise ValueError(f'Certificate with type {cert_type} was not requested') - required_fields = requested_types.get(cert_type, []) - for field in required_fields: - if field not in (fields or {}): - raise ValueError(f'Certificate missing required field: {field}') - - # Try to decrypt fields for the verifier - # Let decryption errors bubble up to the caller (as tests expect) - vc.decrypt_fields(None, verifier_wallet) \ No newline at end of file + _enforce_requested_constraints(allowed_certifiers, requested_types, cert_type, certifier, fields, serial_number) + + # Try to decrypt fields for the verifier (errors bubble up to caller) + vc.decrypt_fields(None, verifier_wallet) + + +# ------- Helpers below keep validate_certificates simple and testable ------- +def _extract_message_certificates(message): + return getattr(message, 'certificates', None) or (message.get('certificates', None) if isinstance(message, dict) else None) + + +def _extract_message_identity_key(message): + return getattr(message, 'identityKey', None) or (message.get('identityKey', None) if isinstance(message, dict) else None) + + +def _normalize_requested_for_utils(req): + allowed_certifiers = [] + requested_types = {} + if req is None: + return allowed_certifiers, requested_types + try: + # RequestedCertificateSet + from bsv.auth.requested_certificate_set import RequestedCertificateSet + if isinstance(req, RequestedCertificateSet): + allowed_certifiers = list(getattr(req, 'certifiers', []) or []) + # For utils we expect plain string type keys; convert bytes keys to base64 strings + mapping = getattr(getattr(req, 'certificate_types', None), 'mapping', {}) or {} + requested_types = {base64.b64encode(k).decode('ascii'): list(v or []) for k, v in mapping.items()} + return allowed_certifiers, requested_types + except Exception: + pass + # dict-like + if isinstance(req, dict): + allowed_certifiers = req.get('certifiers') or req.get('Certifiers') or [] + types_dict = req.get('certificate_types') or req.get('certificateTypes') or req.get('types') or {} + # In utils tests, type keys are simple strings. Keep as-is. + for k, v in types_dict.items(): + requested_types[str(k)] = list(v or []) + return allowed_certifiers, requested_types + + +def _extract_incoming_fields(incoming): + cert_type = incoming.get('type') + serial_number = incoming.get('serialNumber') or incoming.get('serial_number') + subject = incoming.get('subject') + certifier = incoming.get('certifier') + fields = incoming.get('fields') or {} + signature = incoming.get('signature') + keyring = incoming.get('keyring') or {} + return cert_type, serial_number, subject, certifier, fields, signature, keyring + + +def _ensure_subject_matches(subject, identity_key): + if subject != identity_key: + raise ValueError(f'The subject of one of your certificates ("{subject}") is not the same as the request sender ("{identity_key}").') + + +def _build_verifiable_certificate(incoming, cert_type, serial_number, subject, certifier, fields, signature, keyring): + from bsv.auth.verifiable_certificate import VerifiableCertificate + try: + return VerifiableCertificate(cert_type, serial_number, subject, certifier, incoming.get('revocationOutpoint'), fields, keyring, signature) + except Exception: + # Fallback: attempt to wrap a base Certificate if available + from bsv.auth.certificate import Certificate as _Cert, Outpoint as _Out + from bsv.keys import PublicKey as _PK + subj_pk = _PK(subject) + cert_pk = _PK(certifier) if certifier else None + rev = incoming.get('revocationOutpoint') + rev_out = None + if isinstance(rev, dict): + txid = rev.get('txid') or rev.get('txID') or rev.get('txId') + index = rev.get('index') or rev.get('vout') + if txid is not None and index is not None: + rev_out = _Out(txid, int(index)) + base = _Cert(cert_type, serial_number, subj_pk, cert_pk, rev_out, fields, signature) + return VerifiableCertificate(base, keyring) + + +def _enforce_requested_constraints(allowed_certifiers, requested_types, cert_type, certifier, fields, serial_number): + if not (allowed_certifiers or requested_types): + return + if allowed_certifiers and certifier not in allowed_certifiers: + raise ValueError(f'Certificate with serial number {serial_number} has an unrequested certifier') + if requested_types and cert_type not in requested_types: + raise ValueError(f'Certificate with type {cert_type} was not requested') + required_fields = requested_types.get(cert_type, []) + for field in required_fields: + if field not in (fields or {}): + raise ValueError(f'Certificate missing required field: {field}') \ No newline at end of file diff --git a/bsv/wallet/wallet_impl.py b/bsv/wallet/wallet_impl.py index e85c601..670ef35 100644 --- a/bsv/wallet/wallet_impl.py +++ b/bsv/wallet/wallet_impl.py @@ -22,7 +22,7 @@ def _check_permission(self, action: str) -> None: allowed = self.permission_callback(action) else: # Default for CLI: Ask the user for permission - resp = input(f"[Wallet] {action} を許可しますか? [y/N]: ") + resp = input(f"[Wallet] Allow {action}? [y/N]: ") allowed = resp.strip().lower() in ("y", "yes") if os.getenv("BSV_DEBUG", "0") == "1": print(f"[DEBUG WalletImpl._check_permission] action={action!r} allowed={allowed}") @@ -68,7 +68,7 @@ def get_public_key(self, ctx: Any, args: Dict, originator: str) -> Dict: if os.getenv("BSV_DEBUG", "0") == "1": print(f"[DEBUG WalletImpl.get_public_key] originator={originator} seek_permission={seek_permission} args={args}") if seek_permission: - self._check_permission("公開鍵取得 (get_public_key)") + self._check_permission("Get public key") if args.get("identityKey", False): return {"publicKey": self.public_key.hex()} protocol_id = args.get("protocolID") @@ -90,44 +90,13 @@ def get_public_key(self, ctx: Any, args: Dict, originator: str) -> Dict: def encrypt(self, ctx: Any, args: Dict, originator: str) -> Dict: try: encryption_args = args.get("encryption_args", {}) - seek_permission = encryption_args.get("seekPermission") or encryption_args.get("seek_permission") if os.getenv("BSV_DEBUG", "0") == "1": print(f"[DEBUG WalletImpl.encrypt] originator={originator} enc_args={encryption_args}") - if seek_permission: - self._check_permission("暗号化 (encrypt)") + self._maybe_seek_permission("Encrypt", encryption_args) plaintext = args.get("plaintext") if plaintext is None: return {"error": "encrypt: plaintext is required"} - protocol_id = encryption_args.get("protocol_id") - key_id = encryption_args.get("key_id") - counterparty = encryption_args.get("counterparty") - for_self = encryption_args.get("forSelf", False) - if protocol_id and key_id: - if isinstance(protocol_id, dict): - protocol = Protocol(protocol_id.get("securityLevel", 0), protocol_id.get("protocol", "")) - else: - protocol = protocol_id - # normalize counterparty for KeyDeriver - if isinstance(counterparty, dict): - inner = counterparty.get("counterparty") - if isinstance(inner, (bytes, str)): - inner = PublicKey(inner) - cp = Counterparty(counterparty.get("type", CounterpartyType.OTHER), inner) - else: - if isinstance(counterparty, (bytes, str)): - cp = Counterparty(CounterpartyType.OTHER, PublicKey(counterparty)) - elif isinstance(counterparty, PublicKey): - cp = Counterparty(CounterpartyType.OTHER, counterparty) - else: - cp = Counterparty(CounterpartyType.SELF) - pubkey = self.key_deriver.derive_public_key(protocol, key_id, cp, for_self) - else: - if isinstance(counterparty, PublicKey): - pubkey = counterparty - elif isinstance(counterparty, str): - pubkey = PublicKey(counterparty) - else: - pubkey = self.public_key + pubkey = self._resolve_encryption_public_key(encryption_args) ciphertext = pubkey.encrypt(plaintext) return {"ciphertext": ciphertext} except Exception as e: @@ -136,49 +105,13 @@ def encrypt(self, ctx: Any, args: Dict, originator: str) -> Dict: def decrypt(self, ctx: Any, args: Dict, originator: str) -> Dict: try: encryption_args = args.get("encryption_args", {}) - seek_permission = encryption_args.get("seekPermission") or encryption_args.get("seek_permission") if os.getenv("BSV_DEBUG", "0") == "1": print(f"[DEBUG WalletImpl.decrypt] originator={originator} enc_args={encryption_args}") - if seek_permission: - self._check_permission("復号 (decrypt)") + self._maybe_seek_permission("Decrypt", encryption_args) ciphertext = args.get("ciphertext") if ciphertext is None: return {"error": "decrypt: ciphertext is required"} - protocol_id = encryption_args.get("protocol_id") - key_id = encryption_args.get("key_id") - counterparty = encryption_args.get("counterparty") - for_self = encryption_args.get("forSelf", False) - if protocol_id and key_id: - if isinstance(protocol_id, dict): - protocol = Protocol(protocol_id.get("securityLevel", 0), protocol_id.get("protocol", "")) - else: - protocol = protocol_id - # normalize counterparty (sender pub) - if isinstance(counterparty, dict): - inner = counterparty.get("counterparty") - if isinstance(inner, (bytes, str)): - inner = PublicKey(inner) - cp = Counterparty(counterparty.get("type", CounterpartyType.OTHER), inner) - else: - if isinstance(counterparty, (bytes, str)): - cp = Counterparty(CounterpartyType.OTHER, PublicKey(counterparty)) - elif isinstance(counterparty, PublicKey): - cp = Counterparty(CounterpartyType.OTHER, counterparty) - else: - cp = Counterparty(CounterpartyType.SELF) - derived_priv = self.key_deriver.derive_private_key(protocol, key_id, cp) - if os.getenv("BSV_DEBUG", "0") == "1": - print(f"[DEBUG WalletImpl.decrypt] derived_priv int={derived_priv.int():x} ciphertext_len={len(ciphertext)}") - try: - plaintext = derived_priv.decrypt(ciphertext) - if os.getenv("BSV_DEBUG", "0") == "1": - print(f"[DEBUG WalletImpl.decrypt] decrypt success, plaintext={plaintext.hex()}") - except Exception as dec_err: - if os.getenv("BSV_DEBUG", "0") == "1": - print(f"[DEBUG WalletImpl.decrypt] decrypt failed with derived key: {dec_err}") - plaintext = b"" - else: - plaintext = self.private_key.decrypt(ciphertext) + plaintext = self._perform_decrypt_with_args(encryption_args, ciphertext) return {"plaintext": plaintext} except Exception as e: return {"error": f"decrypt: {e}"} @@ -295,7 +228,11 @@ def verify_hmac(self, ctx: Any, args: Dict, originator: str) -> Dict: except Exception as e: return {"error": f"verify_hmac: {e}"} - def abort_action(self, *a, **k): pass + def abort_action(self, *a, **k): + # NOTE: This mock wallet does not manage long-running actions, so there is + # nothing to abort. The method is intentionally left empty to satisfy the + # interface and to document that abort semantics are a no-op in tests. + pass def acquire_certificate(self, ctx: Any, args: Dict, originator: str) -> Dict: # store minimal certificate record for listing/discovery record = { @@ -408,31 +345,39 @@ def list_outputs(self, ctx: Any, args: Dict, originator: str) -> Dict: # Return outputs for the requested basket from the most recent action, and include a BEEF include = (args.get("include") or "").lower() basket = args.get("basket", "") - outputs_desc = [] - # Find the most recent action with outputs matching the basket + outputs_desc = self._find_outputs_for_basket(basket, args) + if os.getenv("REGISTRY_DEBUG") == "1": + print("[DEBUG list_outputs] basket", basket, "outputs_desc", outputs_desc) + beef_bytes = self._build_beef_for_outputs(outputs_desc) + res = {"outputs": self._format_outputs_result(outputs_desc, basket)} + if "entire" in include or "transaction" in include: + res["BEEF"] = beef_bytes + return res + + # ---- Helpers to reduce cognitive complexity in list_outputs ---- + def _find_outputs_for_basket(self, basket: str, args: Dict) -> List[Dict[str, Any]]: + outputs_desc: List[Dict[str, Any]] = [] for action in reversed(self._actions): outs = action.get("outputs") or [] filtered = [o for o in outs if (not basket) or (o.get("basket") == basket)] if filtered: outputs_desc = filtered break - if not outputs_desc: - # Fallback to one mock output - outputs_desc = [ - { - "outputIndex": 0, - "satoshis": 1000, - "lockingScript": b"\x51", - "spendable": True, - "outputDescription": "mock", - "basket": basket, - "tags": args.get("tags", []) or [], - "customInstructions": None, - } - ] - # Build Transaction with these outputs for BEEF inclusion; ensure locking script is the one we stored - if os.getenv("REGISTRY_DEBUG") == "1": - print("[DEBUG list_outputs] basket", basket, "outputs_desc", outputs_desc) + if outputs_desc: + return outputs_desc + # Fallback to one mock output + return [{ + "outputIndex": 0, + "satoshis": 1000, + "lockingScript": b"\x51", + "spendable": True, + "outputDescription": "mock", + "basket": basket, + "tags": args.get("tags", []) or [], + "customInstructions": None, + }] + + def _build_beef_for_outputs(self, outputs_desc: List[Dict[str, Any]]) -> bytes: try: from bsv.transaction import Transaction from bsv.transaction_output import TransactionOutput @@ -440,24 +385,20 @@ def list_outputs(self, ctx: Any, args: Dict, originator: str) -> Dict: tx = Transaction() for o in outputs_desc: ls_hex = o.get("lockingScript") - if isinstance(ls_hex, str): - ls_bytes = bytes.fromhex(ls_hex) - else: - ls_bytes = ls_hex or b"\x51" + ls_bytes = bytes.fromhex(ls_hex) if isinstance(ls_hex, str) else (ls_hex or b"\x51") to = TransactionOutput(Script(ls_bytes), int(o.get("satoshis", 0))) tx.add_output(to) - beef_bytes = tx.to_beef() + return tx.to_beef() except Exception: - beef_bytes = b"" - # Prepare result - result_outputs = [] + return b"" + + def _format_outputs_result(self, outputs_desc: List[Dict[str, Any]], basket: str) -> List[Dict[str, Any]]: + result_outputs: List[Dict[str, Any]] = [] for idx, o in enumerate(outputs_desc): - # ensure lockingScript hex string ls_hex = o.get("lockingScript") if not isinstance(ls_hex, str): ls_hex = (ls_hex or b"\x51").hex() - - ro = { + result_outputs.append({ "outputIndex": int(o.get("outputIndex", idx)), "satoshis": int(o.get("satoshis", 0)), "lockingScript": ls_hex, @@ -467,12 +408,52 @@ def list_outputs(self, ctx: Any, args: Dict, originator: str) -> Dict: "tags": o.get("tags") or [], "customInstructions": o.get("customInstructions"), "txid": "00" * 32, - } - result_outputs.append(ro) - res = {"outputs": result_outputs} - if "entire" in include or "transaction" in include: - res["BEEF"] = beef_bytes - return res + }) + return result_outputs + + # ---- Shared helpers for encrypt/decrypt ---- + def _maybe_seek_permission(self, action_label: str, enc_args: Dict) -> None: + seek_permission = enc_args.get("seekPermission") or enc_args.get("seek_permission") + if seek_permission: + self._check_permission(action_label) + + def _resolve_encryption_public_key(self, enc_args: Dict) -> PublicKey: + protocol_id = enc_args.get("protocol_id") + key_id = enc_args.get("key_id") + counterparty = enc_args.get("counterparty") + for_self = enc_args.get("forSelf", False) + if protocol_id and key_id: + protocol = Protocol(protocol_id.get("securityLevel", 0), protocol_id.get("protocol", "")) if isinstance(protocol_id, dict) else protocol_id + cp = self._normalize_counterparty(counterparty) + return self.key_deriver.derive_public_key(protocol, key_id, cp, for_self) + # Fallbacks + if isinstance(counterparty, PublicKey): + return counterparty + if isinstance(counterparty, str): + return PublicKey(counterparty) + return self.public_key + + def _perform_decrypt_with_args(self, enc_args: Dict, ciphertext: bytes) -> bytes: + protocol_id = enc_args.get("protocol_id") + key_id = enc_args.get("key_id") + counterparty = enc_args.get("counterparty") + if protocol_id and key_id: + protocol = Protocol(protocol_id.get("securityLevel", 0), protocol_id.get("protocol", "")) if isinstance(protocol_id, dict) else protocol_id + cp = self._normalize_counterparty(counterparty) + derived_priv = self.key_deriver.derive_private_key(protocol, key_id, cp) + if os.getenv("BSV_DEBUG", "0") == "1": + print(f"[DEBUG WalletImpl.decrypt] derived_priv int={derived_priv.int():x} ciphertext_len={len(ciphertext)}") + try: + plaintext = derived_priv.decrypt(ciphertext) + if os.getenv("BSV_DEBUG", "0") == "1": + print(f"[DEBUG WalletImpl.decrypt] decrypt success, plaintext={plaintext.hex()}") + except Exception as dec_err: + if os.getenv("BSV_DEBUG", "0") == "1": + print(f"[DEBUG WalletImpl.decrypt] decrypt failed with derived key: {dec_err}") + plaintext = b"" + return plaintext + # Fallback path + return self.private_key.decrypt(ciphertext) def prove_certificate(self, ctx: Any, args: Dict, originator: str) -> Dict: return {"keyringForVerifier": {}, "verifier": args.get("verifier", b"")} def relinquish_certificate(self, ctx: Any, args: Dict, originator: str) -> Dict: @@ -504,7 +485,7 @@ def reveal_counterparty_key_linkage(self, ctx: Any, args: Dict, originator: str) if seek_permission: # Ask the user (or callback) for permission - self._check_permission("鍵リンク開示 (counterparty)") + self._check_permission("Reveal counterparty key linkage") # Real implementation would compute and return linkage data here. For test purposes # we return an empty dict which the serializer converts to an empty payload. @@ -524,7 +505,7 @@ def reveal_specific_key_linkage(self, ctx: Any, args: Dict, originator: str) -> print(f"[DEBUG WalletImpl.reveal_specific_key_linkage] originator={originator} seek_permission={seek_permission} args={args}") if seek_permission: - self._check_permission("鍵リンク開示 (specific)") + self._check_permission("Reveal specific key linkage") return {} except Exception as e: