Skip to content

Commit

Permalink
speeded up BIP32 public derivation by caching
Browse files Browse the repository at this point in the history
  • Loading branch information
fametrano committed Feb 3, 2023
1 parent 7766535 commit 94c8949
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 18 deletions.
38 changes: 22 additions & 16 deletions btclib/bip32/bip32.py
Expand Up @@ -34,6 +34,7 @@
from __future__ import annotations

import copy
import functools
import hmac
from dataclasses import dataclass
from typing import Union
Expand Down Expand Up @@ -277,7 +278,7 @@ def xpub_from_xprv(xprv: BIP32Key) -> str:


@dataclass
class _ExtendedBIP32KeyData(BIP32KeyData):
class _BIP32KeyData(BIP32KeyData):
# extensions used to cache intermediate results
# in multi-level derivation: do not rely on them elsewhere

Expand Down Expand Up @@ -308,17 +309,15 @@ def __init__(
if check_validity:
self.assert_valid()

def __tuple(self) -> tuple[bytes, int, bytes, bytes]:
return (self.parent_fingerprint, self.index, self.chain_code, self.key)

def __child_key_derivation(xkey: _ExtendedBIP32KeyData, index: int) -> None:
xkey.depth += 1
xkey.index = index
if xkey.is_private:
__private_key_derivation(xkey, index)
else:
__public_key_derivation(xkey, index)
def __hash__(self) -> int:
return hash(self.__tuple())


def __private_key_derivation(xkey: _ExtendedBIP32KeyData, index: int) -> None:
def __prv_key_derivation(xkey: _BIP32KeyData, index: int) -> None:
xkey.index = index
Q_bytes = bytes_from_point(mult(xkey.prv_key_int))
xkey.parent_fingerprint = hash160(Q_bytes)[:4]
hmac_ = (
Expand All @@ -341,10 +340,10 @@ def __private_key_derivation(xkey: _ExtendedBIP32KeyData, index: int) -> None:
xkey.pub_key_point = INF


def __public_key_derivation(xkey: _ExtendedBIP32KeyData, index: int) -> None:
@functools.lru_cache() # results are cached to increase efficiency
def __pub_key_derivation(xkey: _BIP32KeyData, index: int) -> None:
xkey.index = index
xkey.parent_fingerprint = hash160(xkey.key)[:4]
if xkey.is_hardened:
raise BTClibValueError("invalid hardened derivation from public key")
hmac_ = hmac.new(
xkey.chain_code,
xkey.key + index.to_bytes(4, byteorder="big", signed=False),
Expand All @@ -370,16 +369,14 @@ def _derive(
err_msg = f"final depth greater than 255: {final_depth}"
raise BTClibValueError(err_msg)

xkey = _ExtendedBIP32KeyData(
xkey = _BIP32KeyData(
version=xkey.version,
depth=xkey.depth,
depth=final_depth,
parent_fingerprint=xkey.parent_fingerprint,
index=xkey.index,
chain_code=xkey.chain_code,
key=xkey.key,
)
for index in indexes:
__child_key_derivation(xkey, index)

if forced_version:
if xkey.version in XPRV_VERSIONS_ALL:
Expand All @@ -393,6 +390,15 @@ def _derive(
raise BTClibValueError(err_msg)
xkey.version = fversion

if xkey.is_private:
for index in indexes:
__prv_key_derivation(xkey, index)
else:
if any(index >= 0x80000000 for index in indexes):
raise BTClibValueError("invalid hardened derivation from public key")
for index in indexes:
__pub_key_derivation(xkey, index)

return xkey


Expand Down
4 changes: 2 additions & 2 deletions btclib/ec/curve_group.py
Expand Up @@ -481,7 +481,7 @@ def multiples(Q: JacPoint, size: int, ec: CurveGroup) -> list[JacPoint]:
MAX_W = 5


@functools.lru_cache() # least recently used cache
@functools.lru_cache() # results are cached to increase efficiency
def cached_multiples(Q: JacPoint, ec: CurveGroup) -> list[JacPoint]:
T = [INFJ, Q]
for i in range(3, 2**MAX_W, 2):
Expand All @@ -490,7 +490,7 @@ def cached_multiples(Q: JacPoint, ec: CurveGroup) -> list[JacPoint]:
return T


@functools.lru_cache()
@functools.lru_cache() # results are cached to increase efficiency
def cached_multiples_fixwind(
Q: JacPoint, ec: CurveGroup, w: int = 4
) -> list[list[JacPoint]]:
Expand Down

0 comments on commit 94c8949

Please sign in to comment.