diff --git a/btclib/psbt.py b/btclib/psbt.py index 631cfe31c..5e72741bd 100644 --- a/btclib/psbt.py +++ b/btclib/psbt.py @@ -23,8 +23,8 @@ from . import script, varint from .alias import ScriptToken from .bip32 import bytes_from_bip32_path -from .psbt_in import PartialSigs, PsbtIn, UnknownData -from .psbt_out import HdKeyPaths, PsbtOut +from .psbt_in import PartialSigs, PsbtIn +from .psbt_out import HdKeyPaths, ProprietaryData, PsbtOut, UnknownData from .scriptpubkey import payload_from_scriptPubKey from .tx import Tx from .tx_out import TxOut @@ -49,7 +49,7 @@ class Psbt(DataClassJsonMixin): outputs: List[PsbtOut] = field(default_factory=list) version: int = 0 hd_keypaths: HdKeyPaths = field(default_factory=HdKeyPaths) - proprietary: Dict[int, Dict[str, str]] = field(default_factory=dict) + proprietary: ProprietaryData = field(default_factory=ProprietaryData) unknown: UnknownData = field(default_factory=UnknownData) @classmethod @@ -79,10 +79,10 @@ def deserialize(cls: Type[_Psbt], data: bytes, assert_valid: bool = True) -> _Ps elif key[0:1] == PSBT_GLOBAL_PROPRIETARY: # TODO: assert not duplicated? prefix = varint.decode(key[1:]) - if prefix not in out.proprietary.keys(): - out.proprietary[prefix] = {} + if prefix not in out.proprietary.data: + out.proprietary.data[prefix] = {} key = key[1 + len(varint.encode(prefix)) :] - out.proprietary[prefix][key.hex()] = value.hex() + out.proprietary.data[prefix][key.hex()] = value.hex() else: # unknown keys # TODO: assert not duplicated? out.unknown.data[key.hex()] = value.hex() @@ -122,7 +122,7 @@ def serialize(self, assert_valid: bool = True) -> bytes: ) out += varint.encode(len(keypath)) + keypath if self.proprietary: - for (owner, dictionary) in self.proprietary.items(): + for (owner, dictionary) in self.proprietary.data.items(): for key_p, value_p in dictionary.items(): key_bytes = ( PSBT_GLOBAL_PROPRIETARY @@ -181,7 +181,7 @@ def assert_valid(self) -> None: assert self.version == 0 self.hd_keypaths.assert_valid() - assert isinstance(self.proprietary, dict) + self.proprietary.assert_valid() self.unknown.assert_valid() def assert_signable(self) -> None: @@ -259,9 +259,7 @@ def _combine_field( ) -> None: item = getattr(psbt_map, key) a = getattr(out, key) - if isinstance(item, dict) and a and isinstance(a, dict): - a.update(item) - elif isinstance(item, dict) or item and not a: + if isinstance(item, dict) or item and not a: setattr(out, key, item) elif isinstance(item, PartialSigs) and isinstance(a, PartialSigs): a.sigs.update(item.sigs) @@ -269,6 +267,8 @@ def _combine_field( a.hd_keypaths.update(item.hd_keypaths) elif isinstance(item, UnknownData) and isinstance(a, UnknownData): a.data.update(item.data) + elif isinstance(item, ProprietaryData) and isinstance(a, ProprietaryData): + a.data.update(item.data) elif item: assert item == a, key diff --git a/btclib/psbt_in.py b/btclib/psbt_in.py index 2fb27171f..2bed0336a 100644 --- a/btclib/psbt_in.py +++ b/btclib/psbt_in.py @@ -20,7 +20,7 @@ from . import dsa, varint from .bip32 import bytes_from_bip32_path -from .psbt_out import HdKeyPaths, UnknownData +from .psbt_out import HdKeyPaths, ProprietaryData, UnknownData from .script import SIGHASHES from .tx import Tx from .tx_in import witness_deserialize, witness_serialize @@ -80,7 +80,7 @@ class PsbtIn(DataClassJsonMixin): metadata=config(encoder=lambda val: [v.hex() for v in val]), ) por_commitment: Optional[str] = None - proprietary: Dict[int, Dict[str, str]] = field(default_factory=dict) + proprietary: ProprietaryData = field(default_factory=ProprietaryData) unknown: UnknownData = field(default_factory=UnknownData) @classmethod @@ -131,10 +131,10 @@ def deserialize( elif key[0:1] == PSBT_IN_PROPRIETARY: # TODO: assert not duplicated? prefix = varint.decode(key[1:]) - if prefix not in out.proprietary.keys(): - out.proprietary[prefix] = {} + if prefix not in out.proprietary.data: + out.proprietary.data[prefix] = {} key = key[1 + len(varint.encode(prefix)) :] - out.proprietary[prefix][key.hex()] = value.hex() + out.proprietary.data[prefix][key.hex()] = value.hex() else: # unknown keys # TODO: assert not duplicated? out.unknown.data[key.hex()] = value.hex() @@ -193,7 +193,7 @@ def serialize(self, assert_valid: bool = True) -> bytes: ) out += varint.encode(len(keypath)) + keypath if self.proprietary: - for (owner, dictionary) in self.proprietary.items(): + for (owner, dictionary) in self.proprietary.data.items(): for key_p, value_p in dictionary.items(): key_bytes = ( PSBT_IN_PROPRIETARY @@ -228,5 +228,5 @@ def assert_valid(self) -> None: if self.por_commitment is not None: assert self.por_commitment.encode("utf-8") - assert isinstance(self.proprietary, dict) + self.proprietary.assert_valid() self.unknown.assert_valid() diff --git a/btclib/psbt_out.py b/btclib/psbt_out.py index 1dd535f0e..a0929c4d7 100644 --- a/btclib/psbt_out.py +++ b/btclib/psbt_out.py @@ -71,6 +71,14 @@ def assert_valid(self) -> None: pass +@dataclass +class ProprietaryData(DataClassJsonMixin): + data: Dict[int, Dict[str, str]] = field(default_factory=dict) + + def assert_valid(self) -> None: + pass + + @dataclass class UnknownData(DataClassJsonMixin): data: Dict[str, str] = field(default_factory=dict) @@ -100,7 +108,7 @@ class PsbtOut(DataClassJsonMixin): default=b"", metadata=config(encoder=lambda v: v.hex(), decoder=bytes.fromhex) ) hd_keypaths: HdKeyPaths = field(default_factory=HdKeyPaths) - proprietary: Dict[int, Dict[str, str]] = field(default_factory=dict) + proprietary: ProprietaryData = field(default_factory=ProprietaryData) unknown: UnknownData = field(default_factory=UnknownData) @classmethod @@ -124,10 +132,10 @@ def deserialize( elif key[0:1] == PSBT_OUT_PROPRIETARY: # TODO: assert not duplicated? prefix = varint.decode(key[1:]) - if prefix not in out.proprietary.keys(): - out.proprietary[prefix] = {} + if prefix not in out.proprietary.data: + out.proprietary.data[prefix] = {} key = key[1 + len(varint.encode(prefix)) :] - out.proprietary[prefix][key.hex()] = value.hex() + out.proprietary.data[prefix][key.hex()] = value.hex() else: # unknown keys # TODO: assert not duplicated? out.unknown.data[key.hex()] = value.hex() @@ -159,7 +167,7 @@ def serialize(self, assert_valid: bool = True) -> bytes: ) out += varint.encode(len(keypath)) + keypath if self.proprietary: - for (owner, dictionary) in self.proprietary.items(): + for (owner, dictionary) in self.proprietary.data.items(): for key_p, value_p in dictionary.items(): key_bytes = ( PSBT_OUT_PROPRIETARY @@ -180,5 +188,5 @@ def serialize(self, assert_valid: bool = True) -> bytes: def assert_valid(self) -> None: self.hd_keypaths.assert_valid() - assert isinstance(self.proprietary, dict) + self.proprietary.assert_valid() self.unknown.assert_valid() diff --git a/btclib/tests/generated_files/psbt.json b/btclib/tests/generated_files/psbt.json index 0ba41d5ef..a5f4b0632 100644 --- a/btclib/tests/generated_files/psbt.json +++ b/btclib/tests/generated_files/psbt.json @@ -85,7 +85,9 @@ "final_script_sig": "", "final_script_witness": [], "por_commitment": null, - "proprietary": {}, + "proprietary": { + "data": {} + }, "unknown": { "data": {} } @@ -120,7 +122,9 @@ "final_script_sig": "", "final_script_witness": [], "por_commitment": null, - "proprietary": {}, + "proprietary": { + "data": {} + }, "unknown": { "data": {} } @@ -138,7 +142,9 @@ } } }, - "proprietary": {}, + "proprietary": { + "data": {} + }, "unknown": { "data": {} } @@ -154,7 +160,9 @@ } } }, - "proprietary": {}, + "proprietary": { + "data": {} + }, "unknown": { "data": {} } @@ -164,7 +172,9 @@ "hd_keypaths": { "hd_keypaths": {} }, - "proprietary": {}, + "proprietary": { + "data": {} + }, "unknown": { "data": {} } diff --git a/btclib/tests/generated_files/psbt_in.json b/btclib/tests/generated_files/psbt_in.json index cccf5a1e1..8404af078 100644 --- a/btclib/tests/generated_files/psbt_in.json +++ b/btclib/tests/generated_files/psbt_in.json @@ -49,7 +49,9 @@ "final_script_sig": "", "final_script_witness": [], "por_commitment": null, - "proprietary": {}, + "proprietary": { + "data": {} + }, "unknown": { "data": {} } diff --git a/btclib/tests/generated_files/psbt_out.json b/btclib/tests/generated_files/psbt_out.json index 164a5ece9..be6117774 100644 --- a/btclib/tests/generated_files/psbt_out.json +++ b/btclib/tests/generated_files/psbt_out.json @@ -9,7 +9,9 @@ } } }, - "proprietary": {}, + "proprietary": { + "data": {} + }, "unknown": { "data": {} } diff --git a/btclib/tests/test_psbt.py b/btclib/tests/test_psbt.py index d356025db..74316cd9e 100644 --- a/btclib/tests/test_psbt.py +++ b/btclib/tests/test_psbt.py @@ -397,9 +397,9 @@ def test_proprietary(): psbt_string = "cHNidP8BAJoCAAAAAljoeiG1ba8MI76OcHBFbDNvfLqlyHV5JPVFiHuyq911AAAAAAD/////g40EJ9DsZQpoqka7CwmK6kQiwHGyyng1Kgd5WdB86h0BAAAAAP////8CcKrwCAAAAAAWABTYXCtx0AYLCcmIauuBXlCZHdoSTQDh9QUAAAAAFgAUAK6pouXw+HaliN9VRuh0LR2HAI8AAAAAAAEAuwIAAAABqtc5MQGL0l+ErkALaISL4J23BurCrBgpi6vucatlb4sAAAAASEcwRAIgWPb8fGoz4bMVSNSByCbAFb0wE1qtQs1neQ2rZtKtJDsCIEoc7SYExnNbY5PltBaR3XiwDwxZQvufdRhW+qk4FX26Af7///8CgPD6AgAAAAAXqRQPuUY0IWlrgsgzryQceMF9295JNIfQ8gonAQAAABepFCnKdPigj4GZlCgYXJe12FLkBj9hh2UAAAAiAgKVg785rgpgl0etGZrd1jT6YQhVnWxc05tMIYPxq5bgf0cwRAIgdAGK1BgAl7hzMjwAFXILNoTMgSOJEEjn282bVa1nnJkCIHPTabdA4+tT3O+jOCPIBwUUylWn3ZVE8VfBZ5EyYRGMASICAtq2H/SaFNtqfQKwzR+7ePxLGDErW05U2uTbovv+9TbXSDBFAiEA9hA4swjcHahlo0hSdG8BV3KTQgjG0kRUOTzZm98iF3cCIAVuZ1pnWm0KArhbFOXikHTYolqbV2C+ooFvZhkQoAbqAQEDBAEAAAABBEdSIQKVg785rgpgl0etGZrd1jT6YQhVnWxc05tMIYPxq5bgfyEC2rYf9JoU22p9ArDNH7t4/EsYMStbTlTa5Nui+/71NtdSriIGApWDvzmuCmCXR60Zmt3WNPphCFWdbFzTm0whg/GrluB/ENkMak8AAACAAAAAgAAAAIAiBgLath/0mhTban0CsM0fu3j8SxgxK1tOVNrk26L7/vU21xDZDGpPAAAAgAAAAIABAACAAAEBIADC6wsAAAAAF6kUt/X69A49QKWkWbHbNTXyty+pIeiHIgIDCJ3BDHrG21T5EymvYXMz2ziM6tDCMfcjN50bmQMLAtxHMEQCIGLrelVhB6fHP0WsSrWh3d9vcHX7EnWWmn84Pv/3hLyyAiAMBdu3Rw2/LwhVfdNWxzJcHtMJE+mWzThAlF2xIijaXwEiAgI63ZBPPW3PWd25BrDe4jUpt/+57VDl6GFRkmhgIh8Oc0cwRAIgZfRbpZmLWaJ//hp77QFq8fH5DVSzqo90UKpfVqJRA70CIH9yRwOtHtuWaAsoS1bU/8uI9/t1nqu+CKow8puFE4PSAQEDBAEAAAABBCIAIIwjUxc3Q7WV37Sge3K6jkLjeX2nTof+fZ10l+OyAokDAQVHUiEDCJ3BDHrG21T5EymvYXMz2ziM6tDCMfcjN50bmQMLAtwhAjrdkE89bc9Z3bkGsN7iNSm3/7ntUOXoYVGSaGAiHw5zUq4iBgI63ZBPPW3PWd25BrDe4jUpt/+57VDl6GFRkmhgIh8OcxDZDGpPAAAAgAAAAIADAACAIgYDCJ3BDHrG21T5EymvYXMz2ziM6tDCMfcjN50bmQMLAtwQ2QxqTwAAAIAAAACAAgAAgAAiAgOppMN/WZbTqiXbrGtXCvBlA5RJKUJGCzVHU+2e7KWHcRDZDGpPAAAAgAAAAIAEAACAACICAn9jmXV9Lv9VoTatAsaEsYOLZVbl8bazQoKpS2tQBRCWENkMak8AAACAAAAAgAUAAIAA" psbt = Psbt.decode(psbt_string) - psbt.inputs[0].proprietary = {300: {"00ff": "00ff"}} - psbt.outputs[1].proprietary = {300: {"ff00": "ff00"}} - psbt.proprietary = {300: {"aaaa": "bbbb"}} + psbt.inputs[0].proprietary.data[300] = {"00ff": "00ff"} + psbt.outputs[1].proprietary.data[300] = {"ff00": "ff00"} + psbt.proprietary.data[300] = {"aaaa": "bbbb"} assert Psbt.decode(psbt.encode()) == psbt