Skip to content

Commit

Permalink
fixed error message and cleaned up code (#103)
Browse files Browse the repository at this point in the history
  • Loading branch information
fametrano committed Feb 5, 2023
1 parent db08bef commit fd643cf
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 22 deletions.
34 changes: 16 additions & 18 deletions btclib/bip32/bip32.py
Expand Up @@ -277,7 +277,7 @@ def xpub_from_xprv(xprv: BIP32Key) -> str:


@dataclass
class _ExtendedBIP32KeyData(BIP32KeyData):
class _BIP32KeyData(BIP32KeyData):
# extensions used to cache intermediate results
# in multi-level derivation: do not rely on them elsewhere

Expand Down Expand Up @@ -309,16 +309,8 @@ def __init__(
self.assert_valid()


def __child_key_derivation(xkey: _ExtendedBIP32KeyData, index: int) -> None:
xkey.depth += 1
def __prv_key_derivation(xkey: _BIP32KeyData, index: int) -> None:
xkey.index = index
if xkey.is_private:
__private_key_derivation(xkey, index)
else:
__public_key_derivation(xkey, index)


def __private_key_derivation(xkey: _ExtendedBIP32KeyData, index: int) -> None:
Q_bytes = bytes_from_point(mult(xkey.prv_key_int))
xkey.parent_fingerprint = hash160(Q_bytes)[:4]
hmac_ = (
Expand All @@ -341,10 +333,9 @@ def __private_key_derivation(xkey: _ExtendedBIP32KeyData, index: int) -> None:
xkey.pub_key_point = INF


def __public_key_derivation(xkey: _ExtendedBIP32KeyData, index: int) -> None:
def __pub_key_derivation(xkey: _BIP32KeyData, index: int) -> None:
xkey.index = index
xkey.parent_fingerprint = hash160(xkey.key)[:4]
if xkey.is_hardened:
raise BTClibValueError("invalid hardened derivation from public key")
hmac_ = hmac.new(
xkey.chain_code,
xkey.key + index.to_bytes(4, byteorder="big", signed=False),
Expand All @@ -370,16 +361,14 @@ def _derive(
err_msg = f"final depth greater than 255: {final_depth}"
raise BTClibValueError(err_msg)

xkey = _ExtendedBIP32KeyData(
xkey = _BIP32KeyData(
version=xkey.version,
depth=xkey.depth,
depth=final_depth,
parent_fingerprint=xkey.parent_fingerprint,
index=xkey.index,
chain_code=xkey.chain_code,
key=xkey.key,
)
for index in indexes:
__child_key_derivation(xkey, index)

if forced_version:
if xkey.version in XPRV_VERSIONS_ALL:
Expand All @@ -393,6 +382,15 @@ def _derive(
raise BTClibValueError(err_msg)
xkey.version = fversion

if xkey.is_private:
for index in indexes:
__prv_key_derivation(xkey, index)
else:
if any(index >= 0x80000000 for index in indexes):
raise BTClibValueError("invalid hardened derivation from public key")
for index in indexes:
__pub_key_derivation(xkey, index)

return xkey


Expand Down Expand Up @@ -438,7 +436,7 @@ def _derive_from_account(
if address_index >= 0x80000000:
raise BTClibValueError("invalid private derivation at address index level")
if address_index > max_index:
raise BTClibValueError(f"too high address index: {branch}")
raise BTClibValueError(f"too high address index: {address_index}")

return _derive(mxkey, f"m/{branch}/{address_index}")

Expand Down
4 changes: 2 additions & 2 deletions btclib/ec/curve_group.py
Expand Up @@ -481,7 +481,7 @@ def multiples(Q: JacPoint, size: int, ec: CurveGroup) -> list[JacPoint]:
MAX_W = 5


@functools.lru_cache() # least recently used cache
@functools.lru_cache() # results are cached to increase efficiency
def cached_multiples(Q: JacPoint, ec: CurveGroup) -> list[JacPoint]:
T = [INFJ, Q]
for i in range(3, 2**MAX_W, 2):
Expand All @@ -490,7 +490,7 @@ def cached_multiples(Q: JacPoint, ec: CurveGroup) -> list[JacPoint]:
return T


@functools.lru_cache()
@functools.lru_cache() # results are cached to increase efficiency
def cached_multiples_fixwind(
Q: JacPoint, ec: CurveGroup, w: int = 4
) -> list[list[JacPoint]]:
Expand Down
4 changes: 2 additions & 2 deletions tests/bip32/test_bip32.py
Expand Up @@ -300,15 +300,15 @@ def test_derive_from_account() -> None:
with pytest.raises(BTClibValueError, match=err_msg):
derive_from_account(mxpub, 0xFFFF + 1, 0)

err_msg = "invalid branch: "
err_msg = "invalid branch: 2"
with pytest.raises(BTClibValueError, match=err_msg):
derive_from_account(mxpub, 2, 0)

err_msg = "invalid private derivation at address index level"
with pytest.raises(BTClibValueError, match=err_msg):
derive_from_account(mxpub, 0, 0x80000000)

err_msg = "too high address index: "
err_msg = "too high address index: 65536"
with pytest.raises(BTClibValueError, match=err_msg):
derive_from_account(mxpub, 0, 0xFFFF + 1)

Expand Down

0 comments on commit fd643cf

Please sign in to comment.