From 8e76ad4410cd979c856170dcd11a5d51cd7490e9 Mon Sep 17 00:00:00 2001 From: fametrano Date: Fri, 3 Feb 2023 13:13:06 +0100 Subject: [PATCH] speeded up BIP32 public derivation by caching --- btclib/bip32/bip32.py | 38 ++++++++++++++++++++++---------------- btclib/ec/curve_group.py | 4 ++-- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/btclib/bip32/bip32.py b/btclib/bip32/bip32.py index 8f09c226..10b369e8 100644 --- a/btclib/bip32/bip32.py +++ b/btclib/bip32/bip32.py @@ -34,6 +34,7 @@ from __future__ import annotations import copy +import functools import hmac from dataclasses import dataclass from typing import Union @@ -277,7 +278,7 @@ def xpub_from_xprv(xprv: BIP32Key) -> str: @dataclass -class _ExtendedBIP32KeyData(BIP32KeyData): +class _BIP32KeyData(BIP32KeyData): # extensions used to cache intermediate results # in multi-level derivation: do not rely on them elsewhere @@ -308,17 +309,15 @@ def __init__( if check_validity: self.assert_valid() + def __tuple(self) -> tuple[bytes, int, bytes, bytes]: + return (self.parent_fingerprint, self.index, self.chain_code, self.key) -def __child_key_derivation(xkey: _ExtendedBIP32KeyData, index: int) -> None: - xkey.depth += 1 - xkey.index = index - if xkey.is_private: - __private_key_derivation(xkey, index) - else: - __public_key_derivation(xkey, index) + def __hash__(self) -> int: + return hash(self.__tuple()) -def __private_key_derivation(xkey: _ExtendedBIP32KeyData, index: int) -> None: +def __prv_key_derivation(xkey: _BIP32KeyData, index: int) -> None: + xkey.index = index Q_bytes = bytes_from_point(mult(xkey.prv_key_int)) xkey.parent_fingerprint = hash160(Q_bytes)[:4] hmac_ = ( @@ -341,10 +340,10 @@ def __private_key_derivation(xkey: _ExtendedBIP32KeyData, index: int) -> None: xkey.pub_key_point = INF -def __public_key_derivation(xkey: _ExtendedBIP32KeyData, index: int) -> None: +@functools.lru_cache() # results are cached to increase efficiency +def __pub_key_derivation(xkey: _BIP32KeyData, index: int) -> None: + xkey.index = index xkey.parent_fingerprint = hash160(xkey.key)[:4] - if xkey.is_hardened: - raise BTClibValueError("invalid hardened derivation from public key") hmac_ = hmac.new( xkey.chain_code, xkey.key + index.to_bytes(4, byteorder="big", signed=False), @@ -370,16 +369,14 @@ def _derive( err_msg = f"final depth greater than 255: {final_depth}" raise BTClibValueError(err_msg) - xkey = _ExtendedBIP32KeyData( + xkey = _BIP32KeyData( version=xkey.version, - depth=xkey.depth, + depth=final_depth, parent_fingerprint=xkey.parent_fingerprint, index=xkey.index, chain_code=xkey.chain_code, key=xkey.key, ) - for index in indexes: - __child_key_derivation(xkey, index) if forced_version: if xkey.version in XPRV_VERSIONS_ALL: @@ -393,6 +390,15 @@ def _derive( raise BTClibValueError(err_msg) xkey.version = fversion + if xkey.is_private: + for index in indexes: + __prv_key_derivation(xkey, index) + else: + if any(index >= 0x80000000 for index in indexes): + raise BTClibValueError("invalid hardened derivation from public key") + for index in indexes: + __pub_key_derivation(xkey, index) + return xkey diff --git a/btclib/ec/curve_group.py b/btclib/ec/curve_group.py index e1dc1bc6..651ad9f9 100644 --- a/btclib/ec/curve_group.py +++ b/btclib/ec/curve_group.py @@ -481,7 +481,7 @@ def multiples(Q: JacPoint, size: int, ec: CurveGroup) -> list[JacPoint]: MAX_W = 5 -@functools.lru_cache() # least recently used cache +@functools.lru_cache() # results are cached to increase efficiency def cached_multiples(Q: JacPoint, ec: CurveGroup) -> list[JacPoint]: T = [INFJ, Q] for i in range(3, 2**MAX_W, 2): @@ -490,7 +490,7 @@ def cached_multiples(Q: JacPoint, ec: CurveGroup) -> list[JacPoint]: return T -@functools.lru_cache() +@functools.lru_cache() # results are cached to increase efficiency def cached_multiples_fixwind( Q: JacPoint, ec: CurveGroup, w: int = 4 ) -> list[list[JacPoint]]: