Skip to content

Commit

Permalink
tests: Test that PSBT_OUT_TAP_TREE is combined correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
achow101 committed Oct 6, 2022
1 parent 7df6e1b commit 22c051c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
22 changes: 20 additions & 2 deletions test/functional/rpc_psbt.py
Expand Up @@ -27,6 +27,7 @@
PSBT_IN_SHA256,
PSBT_IN_HASH160,
PSBT_IN_HASH256,
PSBT_OUT_TAP_TREE,
)
from test_framework.test_framework import BitcoinTestFramework
from test_framework.util import (
Expand Down Expand Up @@ -779,9 +780,18 @@ def test_psbt_input_keys(psbt_input, keys):
self.generate(self.nodes[0], 1)
self.nodes[0].importdescriptors([{"desc": descsum_create("tr({})".format(privkey)), "timestamp":"now"}])

psbt = watchonly.sendall([wallet.getnewaddress()])["psbt"]
psbt = watchonly.sendall([wallet.getnewaddress(), addr])["psbt"]
psbt = self.nodes[0].walletprocesspsbt(psbt)["psbt"]
self.nodes[0].sendrawtransaction(self.nodes[0].finalizepsbt(psbt)["hex"])
txid = self.nodes[0].sendrawtransaction(self.nodes[0].finalizepsbt(psbt)["hex"])
vout = find_vout_for_address(self.nodes[0], txid, addr)

# Make sure tap tree is in psbt
parsed_psbt = PSBT.from_base64(psbt)
assert_greater_than(len(parsed_psbt.o[vout].map[PSBT_OUT_TAP_TREE]), 0)
assert "taproot_tree" in self.nodes[0].decodepsbt(psbt)["outputs"][vout]
parsed_psbt.make_blank()
comb_psbt = self.nodes[0].combinepsbt([psbt, parsed_psbt.to_base64()])
assert_equal(comb_psbt, psbt)

self.log.info("Test that walletprocesspsbt both updates and signs a non-updated psbt containing Taproot inputs")
addr = self.nodes[0].getnewaddress("", "bech32m")
Expand All @@ -793,6 +803,14 @@ def test_psbt_input_keys(psbt_input, keys):
self.nodes[0].sendrawtransaction(rawtx)
self.generate(self.nodes[0], 1)

# Make sure tap tree is not in psbt
parsed_psbt = PSBT.from_base64(psbt)
assert PSBT_OUT_TAP_TREE not in parsed_psbt.o[0].map
assert "taproot_tree" not in self.nodes[0].decodepsbt(psbt)["outputs"][0]
parsed_psbt.make_blank()
comb_psbt = self.nodes[0].combinepsbt([psbt, parsed_psbt.to_base64()])
assert_equal(comb_psbt, psbt)

self.log.info("Test decoding PSBT with per-input preimage types")
# note that the decodepsbt RPC doesn't check whether preimages and hashes match
hash_ripemd160, preimage_ripemd160 = random_bytes(20), random_bytes(50)
Expand Down
9 changes: 9 additions & 0 deletions test/functional/test_framework/psbt.py
Expand Up @@ -123,6 +123,15 @@ def serialize(self):
psbt = [x.serialize() for x in [self.g] + self.i + self.o]
return b"psbt\xff" + b"".join(psbt)

def make_blank(self):
"""
Remove all fields except for PSBT_GLOBAL_UNSIGNED_TX
"""
for m in self.i + self.o:
m.map.clear()

self.g = PSBTMap(map={0: self.g.map[0]})

def to_base64(self):
return base64.b64encode(self.serialize()).decode("utf8")

Expand Down

0 comments on commit 22c051c

Please sign in to comment.