Skip to content

Commit

Permalink
used cls.serialize()
Browse files Browse the repository at this point in the history
  • Loading branch information
fametrano committed Nov 11, 2020
1 parent 9402694 commit 5b832c4
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 101 deletions.
39 changes: 6 additions & 33 deletions btclib/psbt.py
Expand Up @@ -22,7 +22,6 @@

from . import script, varint
from .alias import ScriptToken
from .bip32 import bytes_from_bip32_path
from .psbt_in import PartialSigs, PsbtIn
from .psbt_out import HdKeyPaths, ProprietaryData, PsbtOut, UnknownData
from .scriptpubkey import payload_from_scriptPubKey
Expand Down Expand Up @@ -64,27 +63,21 @@ def deserialize(cls: Type[_Psbt], data: bytes, assert_valid: bool = True) -> _Ps
for key, value in global_map.items():
if key[0:1] == PSBT_GLOBAL_UNSIGNED_TX:
assert len(key) == 1, f"invalid key length: {len(key)}"
assert not out.tx.nVersion, "duplicated transaction"
out.tx = Tx.deserialize(value) # legacy trensaction
elif key[0:1] == PSBT_GLOBAL_VERSION:
assert len(key) == 1, f"invalid key length: {len(key)}"
assert not out.version, "duplicated version"
assert len(value) == 4, f"invalid version length: {len(value)}"
out.version = int.from_bytes(value, "little")
elif key[0:1] == PSBT_GLOBAL_XPUB:
# why extended key here?
assert len(key) == 78 + 1, f"invalid key length: {len(key)}"
# TODO: assert not duplicated?
out.hd_keypaths.add_hd_keypath(key[1:], value[:4], value[4:])
elif key[0:1] == PSBT_GLOBAL_PROPRIETARY:
# TODO: assert not duplicated?
prefix = varint.decode(key[1:])
if prefix not in out.proprietary.data:
out.proprietary.data[prefix] = {}
key = key[1 + len(varint.encode(prefix)) :]
out.proprietary.data[prefix][key.hex()] = value.hex()
else: # unknown keys
# TODO: assert not duplicated?
out.unknown.data[key.hex()] = value.hex()

assert out.tx.nVersion, "missing transaction"
Expand Down Expand Up @@ -112,32 +105,12 @@ def serialize(self, assert_valid: bool = True) -> bytes:
if self.version:
out += b"\x01" + PSBT_GLOBAL_VERSION
out += b"\x04" + self.version.to_bytes(4, "little")
if self.hd_keypaths:
for pubkey, hd_keypath in self.hd_keypaths.hd_keypaths.items():
pubkey_bytes = PSBT_GLOBAL_XPUB + bytes.fromhex(pubkey)
out += varint.encode(len(pubkey_bytes)) + pubkey_bytes
keypath = bytes.fromhex(hd_keypath["fingerprint"])
keypath += bytes_from_bip32_path(
hd_keypath["derivation_path"], "little"
)
out += varint.encode(len(keypath)) + keypath
if self.proprietary:
for (owner, dictionary) in self.proprietary.data.items():
for key_p, value_p in dictionary.items():
key_bytes = (
PSBT_GLOBAL_PROPRIETARY
+ varint.encode(owner)
+ bytes.fromhex(key_p)
)
out += varint.encode(len(key_bytes)) + key_bytes
t = bytes.fromhex(value_p)
out += varint.encode(len(t)) + t
if self.unknown:
for key_u, value_u in self.unknown.data.items():
t = bytes.fromhex(key_u)
out += varint.encode(len(t)) + t
t = bytes.fromhex(value_u)
out += varint.encode(len(t)) + t
if self.hd_keypaths.hd_keypaths:
out += self.hd_keypaths.serialize(PSBT_GLOBAL_XPUB, assert_valid)
if self.proprietary.data:
out += self.proprietary.serialize(PSBT_GLOBAL_PROPRIETARY, assert_valid)
if self.unknown.data:
out += self.unknown.serialize(assert_valid)

out += PSBT_DELIMITER
for input_map in self.inputs:
Expand Down
43 changes: 6 additions & 37 deletions btclib/psbt_in.py
Expand Up @@ -19,7 +19,6 @@
from dataclasses_json import DataClassJsonMixin, config

from . import dsa, varint
from .bip32 import bytes_from_bip32_path
from .psbt_out import HdKeyPaths, ProprietaryData, UnknownData
from .script import SIGHASHES
from .tx import Tx
Expand Down Expand Up @@ -91,52 +90,42 @@ def deserialize(
for key, value in input_map.items():
if key[0:1] == PSBT_IN_NON_WITNESS_UTXO:
assert len(key) == 1, f"invalid key length: {len(key)}"
assert out.non_witness_utxo is None, "duplicated non_witness_utxo"
out.non_witness_utxo = Tx.deserialize(value)
elif key[0:1] == PSBT_IN_WITNESS_UTXO:
assert len(key) == 1, f"invalid key length: {len(key)}"
assert out.witness_utxo is None, "duplicated witness_utxo"
out.witness_utxo = TxOut.deserialize(value)
elif key[0:1] == PSBT_IN_PARTIAL_SIG:
assert len(key) == 33 + 1, f"invalid key length: {len(key)}"
out.partial_sigs.sigs[key[1:].hex()] = value.hex()
elif key[0:1] == PSBT_IN_SIGHASH_TYPE:
assert len(key) == 1, f"invalid key length: {len(key)}"
assert out.sighash is None, "duplicated sighash"
assert len(value) == 4
out.sighash = int.from_bytes(value, "little")
elif key[0:1] == PSBT_IN_FINAL_SCRIPTSIG:
assert len(key) == 1, f"invalid key length: {len(key)}"
assert out.final_script_sig == b"", "duplicated final_script_sig"
out.final_script_sig = value
elif key[0:1] == PSBT_IN_FINAL_SCRIPTWITNESS:
assert len(key) == 1, f"invalid key length: {len(key)}"
assert not out.final_script_witness, "duplicated final_script_witness"
out.final_script_witness = witness_deserialize(value)
elif key[0:1] == PSBT_IN_POR_COMMITMENT:
assert len(key) == 1, f"invalid key length: {len(key)}"
out.por_commitment = value.decode("utf-8") # TODO: see bip127
elif key[0:1] == PSBT_IN_REDEEM_SCRIPT:
assert len(key) == 1, f"invalid key length: {len(key)}"
assert out.redeem_script == b"", "duplicated redeem_script"
out.redeem_script = value
elif key[0:1] == PSBT_IN_WITNESS_SCRIPT:
assert len(key) == 1, f"invalid key length: {len(key)}"
assert out.witness_script == b"", "duplicated witness_script"
out.witness_script = value
elif key[0:1] == PSBT_IN_BIP32_DERIVATION:
assert len(key) == 33 + 1, f"invalid key length: {len(key)}"
# TODO: assert not duplicated?
out.hd_keypaths.add_hd_keypath(key[1:], value[:4], value[4:])
elif key[0:1] == PSBT_IN_PROPRIETARY:
# TODO: assert not duplicated?
prefix = varint.decode(key[1:])
if prefix not in out.proprietary.data:
out.proprietary.data[prefix] = {}
key = key[1 + len(varint.encode(prefix)) :]
out.proprietary.data[prefix][key.hex()] = value.hex()
else: # unknown keys
# TODO: assert not duplicated?
out.unknown.data[key.hex()] = value.hex()

if assert_valid:
Expand Down Expand Up @@ -183,32 +172,12 @@ def serialize(self, assert_valid: bool = True) -> bytes:
out += b"\x01" + PSBT_IN_POR_COMMITMENT
c = self.por_commitment.encode("utf-8")
out += varint.encode(len(c)) + c
if self.hd_keypaths:
for pubkey, hd_keypath in self.hd_keypaths.hd_keypaths.items():
pubkey_bytes = PSBT_IN_BIP32_DERIVATION + bytes.fromhex(pubkey)
out += varint.encode(len(pubkey_bytes)) + pubkey_bytes
keypath = bytes.fromhex(hd_keypath["fingerprint"])
keypath += bytes_from_bip32_path(
hd_keypath["derivation_path"], "little"
)
out += varint.encode(len(keypath)) + keypath
if self.proprietary:
for (owner, dictionary) in self.proprietary.data.items():
for key_p, value_p in dictionary.items():
key_bytes = (
PSBT_IN_PROPRIETARY
+ varint.encode(owner)
+ bytes.fromhex(key_p)
)
out += varint.encode(len(key_bytes)) + key_bytes
t = bytes.fromhex(value_p)
out += varint.encode(len(t)) + t
if self.unknown:
for key_u, value_u in self.unknown.data.items():
t = bytes.fromhex(key_u)
out += varint.encode(len(t)) + t
t = bytes.fromhex(value_u)
out += varint.encode(len(t)) + t
if self.hd_keypaths.hd_keypaths:
out += self.hd_keypaths.serialize(PSBT_IN_BIP32_DERIVATION, assert_valid)
if self.proprietary.data:
out += self.proprietary.serialize(PSBT_IN_PROPRIETARY, assert_valid)
if self.unknown.data:
out += self.unknown.serialize(assert_valid)

return out

Expand Down
117 changes: 86 additions & 31 deletions btclib/psbt_out.py
Expand Up @@ -42,6 +42,9 @@ def _pubkey_to_hex_string(pubkey: PubKey) -> str:
return pubkey.hex()


_HdKeyPaths = TypeVar("_HdKeyPaths", bound="HdKeyPaths")


@dataclass
class HdKeyPaths(DataClassJsonMixin):
hd_keypaths: Dict[str, Dict[str, str]] = field(default_factory=dict)
Expand All @@ -67,22 +70,98 @@ def get_hd_keypath(self, key: PubKey) -> Tuple[str, str]:
entry = self.hd_keypaths[key_str]
return entry["fingerprint"], entry["derivation_path"]

@classmethod
def deserialize(cls: Type[_HdKeyPaths], assert_valid: bool = True) -> _HdKeyPaths:
out = cls()

if assert_valid:
out.assert_valid()
return out

def serialize(self, marker: bytes, assert_valid: bool = True) -> bytes:

if assert_valid:
self.assert_valid()

out = b""
for pubkey, hd_keypath in self.hd_keypaths.items():
pubkey_bytes = marker + bytes.fromhex(pubkey)
out += varint.encode(len(pubkey_bytes)) + pubkey_bytes
keypath = bytes.fromhex(hd_keypath["fingerprint"])
keypath += bytes_from_bip32_path(hd_keypath["derivation_path"], "little")
out += varint.encode(len(keypath)) + keypath

return out

def assert_valid(self) -> None:
pass


_ProprietaryData = TypeVar("_ProprietaryData", bound="ProprietaryData")


@dataclass
class ProprietaryData(DataClassJsonMixin):
data: Dict[int, Dict[str, str]] = field(default_factory=dict)

@classmethod
def deserialize(
cls: Type[_ProprietaryData], assert_valid: bool = True
) -> _ProprietaryData:
out = cls()

if assert_valid:
out.assert_valid()
return out

def serialize(self, marker: bytes, assert_valid: bool = True) -> bytes:

if assert_valid:
self.assert_valid()

out = b""
for (owner, dictionary) in self.data.items():
for key_p, value_p in dictionary.items():
key_bytes = marker + varint.encode(owner) + bytes.fromhex(key_p)
out += varint.encode(len(key_bytes)) + key_bytes
t = bytes.fromhex(value_p)
out += varint.encode(len(t)) + t

return out

def assert_valid(self) -> None:
pass


_UnknownData = TypeVar("_UnknownData", bound="UnknownData")


@dataclass
class UnknownData(DataClassJsonMixin):
data: Dict[str, str] = field(default_factory=dict)

@classmethod
def deserialize(cls: Type[_UnknownData], assert_valid: bool = True) -> _UnknownData:
out = cls()

if assert_valid:
out.assert_valid()
return out

def serialize(self, assert_valid: bool = True) -> bytes:

if assert_valid:
self.assert_valid()

out = b""
for key_u, value_u in self.data.items():
t = bytes.fromhex(key_u)
out += varint.encode(len(t)) + t
t = bytes.fromhex(value_u)
out += varint.encode(len(t)) + t

return out

def assert_valid(self) -> None:
for key, value in self.data.items():
# TODO: verify that pubkey is a valid secp256k1 Point
Expand Down Expand Up @@ -119,25 +198,20 @@ def deserialize(
for key, value in output_map.items():
if key[0:1] == PSBT_OUT_REDEEM_SCRIPT:
assert len(key) == 1, f"invalid key length: {len(key)}"
assert out.redeem_script == b"", "duplicated redeem_script"
out.redeem_script = value
elif key[0:1] == PSBT_OUT_WITNESS_SCRIPT:
assert len(key) == 1, f"invalid key length: {len(key)}"
assert out.witness_script == b"", "duplicated witness_script"
out.witness_script = value
elif key[0:1] == PSBT_OUT_BIP32_DERIVATION:
assert len(key) == 33 + 1, f"invalid key length: {len(key)}"
# TODO: assert not duplicated?
out.hd_keypaths.add_hd_keypath(key[1:], value[:4], value[4:])
elif key[0:1] == PSBT_OUT_PROPRIETARY:
# TODO: assert not duplicated?
prefix = varint.decode(key[1:])
if prefix not in out.proprietary.data:
out.proprietary.data[prefix] = {}
key = key[1 + len(varint.encode(prefix)) :]
out.proprietary.data[prefix][key.hex()] = value.hex()
else: # unknown keys
# TODO: assert not duplicated?
out.unknown.data[key.hex()] = value.hex()

if assert_valid:
Expand All @@ -148,6 +222,7 @@ def serialize(self, assert_valid: bool = True) -> bytes:

if assert_valid:
self.assert_valid()
assert_valid = False

out = b""

Expand All @@ -157,32 +232,12 @@ def serialize(self, assert_valid: bool = True) -> bytes:
if self.witness_script:
out += b"\x01" + PSBT_OUT_WITNESS_SCRIPT
out += varint.encode(len(self.witness_script)) + self.witness_script
if self.hd_keypaths:
for pubkey, hd_keypath in self.hd_keypaths.hd_keypaths.items():
pubkey_bytes = PSBT_OUT_BIP32_DERIVATION + bytes.fromhex(pubkey)
out += varint.encode(len(pubkey_bytes)) + pubkey_bytes
keypath = bytes.fromhex(hd_keypath["fingerprint"])
keypath += bytes_from_bip32_path(
hd_keypath["derivation_path"], "little"
)
out += varint.encode(len(keypath)) + keypath
if self.proprietary:
for (owner, dictionary) in self.proprietary.data.items():
for key_p, value_p in dictionary.items():
key_bytes = (
PSBT_OUT_PROPRIETARY
+ varint.encode(owner)
+ bytes.fromhex(key_p)
)
out += varint.encode(len(key_bytes)) + key_bytes
t = bytes.fromhex(value_p)
out += varint.encode(len(t)) + t
if self.unknown:
for key_u, value_u in self.unknown.data.items():
t = bytes.fromhex(key_u)
out += varint.encode(len(t)) + t
t = bytes.fromhex(value_u)
out += varint.encode(len(t)) + t
if self.hd_keypaths.hd_keypaths:
out += self.hd_keypaths.serialize(PSBT_OUT_BIP32_DERIVATION, assert_valid)
if self.proprietary.data:
out += self.proprietary.serialize(PSBT_OUT_PROPRIETARY, assert_valid)
if self.unknown.data:
out += self.unknown.serialize(assert_valid)

return out

Expand Down

0 comments on commit 5b832c4

Please sign in to comment.