103 changes: 96 additions & 7 deletions test/functional/p2p_segwit.py
Expand Up @@ -25,6 +25,7 @@
MSG_BLOCK,
MSG_TX,
MSG_WITNESS_FLAG,
MSG_WTX,
NODE_NETWORK,
NODE_WITNESS,
msg_no_witness_block,
Expand All @@ -34,6 +35,7 @@
msg_tx,
msg_block,
msg_no_witness_tx,
msg_verack,
ser_uint256,
ser_vector,
sha256,
Expand Down Expand Up @@ -81,6 +83,7 @@
softfork_active,
hex_str_to_bytes,
assert_raises_rpc_error,
wait_until,
)

# The versionbit bit used to signal activation of SegWit
Expand Down Expand Up @@ -143,25 +146,47 @@ def test_witness_block(node, p2p, block, accepted, with_witness=True, reason=Non


class TestP2PConn(P2PInterface):
def __init__(self):
def __init__(self, wtxidrelay=False):
super().__init__()
self.getdataset = set()
self.last_wtxidrelay = []
self.lastgetdata = []
self.wtxidrelay = wtxidrelay

# Avoid sending out msg_getdata in the mininode thread as a reply to invs.
# They are not needed and would only lead to races because we send msg_getdata out in the test thread
def on_inv(self, message):
pass

def on_version(self, message):
if self.wtxidrelay:
super().on_version(message)
else:
self.send_message(msg_verack())
self.nServices = message.nServices

def on_getdata(self, message):
self.lastgetdata = message.inv
for inv in message.inv:
self.getdataset.add(inv.hash)

def announce_tx_and_wait_for_getdata(self, tx, timeout=60, success=True):
def on_wtxidrelay(self, message):
self.last_wtxidrelay.append(message)

def announce_tx_and_wait_for_getdata(self, tx, timeout=60, success=True, use_wtxid=False):
with mininode_lock:
self.last_message.pop("getdata", None)
self.send_message(msg_inv(inv=[CInv(MSG_TX, tx.sha256)]))
if use_wtxid:
wtxid = tx.calc_sha256(True)
self.send_message(msg_inv(inv=[CInv(MSG_WTX, wtxid)]))
else:
self.send_message(msg_inv(inv=[CInv(MSG_TX, tx.sha256)]))

if success:
self.wait_for_getdata([tx.sha256], timeout)
if use_wtxid:
self.wait_for_getdata([wtxid], timeout)
else:
self.wait_for_getdata([tx.sha256], timeout)
else:
time.sleep(timeout)
assert not self.last_message.get("getdata")
Expand Down Expand Up @@ -277,6 +302,7 @@ def run_test(self):
self.test_upgrade_after_activation()
self.test_witness_sigops()
self.test_superfluous_witness()
self.test_wtxid_relay()

# Individual tests

Expand Down Expand Up @@ -1270,7 +1296,6 @@ def test_tx_relay_after_segwit_activation(self):
test_transaction_acceptance(self.nodes[0], self.test_node, tx, with_witness=True, accepted=False)

# Verify that removing the witness succeeds.
self.test_node.announce_tx_and_wait_for_getdata(tx)
test_transaction_acceptance(self.nodes[0], self.test_node, tx, with_witness=False, accepted=True)

# Now try to add extra witness data to a valid witness tx.
Expand All @@ -1297,8 +1322,6 @@ def test_tx_relay_after_segwit_activation(self):
# Node will not be blinded to the transaction
self.std_node.announce_tx_and_wait_for_getdata(tx3)
test_transaction_acceptance(self.nodes[1], self.std_node, tx3, True, False, 'tx-size')
self.std_node.announce_tx_and_wait_for_getdata(tx3)
test_transaction_acceptance(self.nodes[1], self.std_node, tx3, True, False, 'tx-size')

# Remove witness stuffing, instead add extra witness push on stack
tx3.vout[0] = CTxOut(tx2.vout[0].nValue - 1000, CScript([OP_TRUE, OP_DROP] * 15 + [OP_TRUE]))
Expand Down Expand Up @@ -2016,6 +2039,11 @@ def test_witness_sigops(self):

# TODO: test p2sh sigop counting

# Cleanup and prep for next test
self.utxo.pop(0)
self.utxo.append(UTXO(tx2.sha256, 0, tx2.vout[0].nValue))

@subtest # type: ignore
def test_superfluous_witness(self):
# Serialization of tx that puts witness flag to 3 always
def serialize_with_bogus_witness(tx):
Expand Down Expand Up @@ -2059,6 +2087,67 @@ def serialize(self):
with self.nodes[0].assert_debug_log(['Unknown transaction optional data']):
self.nodes[0].p2p.send_and_ping(msg_bogus_tx(tx))

@subtest # type: ignore
def test_wtxid_relay(self):
# Use brand new nodes to avoid contamination from earlier tests
self.wtx_node = self.nodes[0].add_p2p_connection(TestP2PConn(wtxidrelay=True), services=NODE_NETWORK | NODE_WITNESS)
self.tx_node = self.nodes[0].add_p2p_connection(TestP2PConn(wtxidrelay=False), services=NODE_NETWORK | NODE_WITNESS)

# Check wtxidrelay feature negotiation message through connecting a new peer
def received_wtxidrelay():
return (len(self.wtx_node.last_wtxidrelay) > 0)
wait_until(received_wtxidrelay, timeout=60, lock=mininode_lock)

# Create a Segwit output from the latest UTXO
# and announce it to the network
witness_program = CScript([OP_TRUE])
witness_hash = sha256(witness_program)
script_pubkey = CScript([OP_0, witness_hash])

tx = CTransaction()
tx.vin.append(CTxIn(COutPoint(self.utxo[0].sha256, self.utxo[0].n), b""))
tx.vout.append(CTxOut(self.utxo[0].nValue - 1000, script_pubkey))
tx.rehash()

# Create a Segwit transaction
tx2 = CTransaction()
tx2.vin.append(CTxIn(COutPoint(tx.sha256, 0), b""))
tx2.vout.append(CTxOut(tx.vout[0].nValue - 1000, script_pubkey))
tx2.wit.vtxinwit.append(CTxInWitness())
tx2.wit.vtxinwit[0].scriptWitness.stack = [witness_program]
tx2.rehash()

# Announce Segwit transaction with wtxid
# and wait for getdata
self.wtx_node.announce_tx_and_wait_for_getdata(tx2, use_wtxid=True)
with mininode_lock:
lgd = self.wtx_node.lastgetdata[:]
assert_equal(lgd, [CInv(MSG_WTX, tx2.calc_sha256(True))])

# Announce Segwit transaction from non wtxidrelay peer
# and wait for getdata
self.tx_node.announce_tx_and_wait_for_getdata(tx2, use_wtxid=False)
with mininode_lock:
lgd = self.tx_node.lastgetdata[:]
assert_equal(lgd, [CInv(MSG_TX|MSG_WITNESS_FLAG, tx2.sha256)])

# Send tx2 through; it's an orphan so won't be accepted
with mininode_lock:
self.tx_node.last_message.pop("getdata", None)
test_transaction_acceptance(self.nodes[0], self.tx_node, tx2, with_witness=True, accepted=False)

# Expect a request for parent (tx) due to use of non-WTX peer
self.tx_node.wait_for_getdata([tx.sha256], 60)
with mininode_lock:
lgd = self.tx_node.lastgetdata[:]
assert_equal(lgd, [CInv(MSG_TX|MSG_WITNESS_FLAG, tx.sha256)])

# Send tx through
test_transaction_acceptance(self.nodes[0], self.tx_node, tx, with_witness=False, accepted=True)

# Check tx2 is there now
assert_equal(tx2.hash in self.nodes[0].getrawmempool(), True)


if __name__ == '__main__':
SegWitTest().main()
12 changes: 7 additions & 5 deletions test/functional/p2p_tx_download.py
Expand Up @@ -12,6 +12,7 @@
FromHex,
MSG_TX,
MSG_TYPE_MASK,
MSG_WTX,
msg_inv,
msg_notfound,
)
Expand All @@ -36,20 +37,21 @@ def __init__(self):

def on_getdata(self, message):
for i in message.inv:
if i.type & MSG_TYPE_MASK == MSG_TX:
if i.type & MSG_TYPE_MASK == MSG_TX or i.type & MSG_TYPE_MASK == MSG_WTX:
self.tx_getdata_count += 1


# Constants from net_processing
GETDATA_TX_INTERVAL = 60 # seconds
MAX_GETDATA_RANDOM_DELAY = 2 # seconds
INBOUND_PEER_TX_DELAY = 2 # seconds
TXID_RELAY_DELAY = 2 # seconds
MAX_GETDATA_IN_FLIGHT = 100
TX_EXPIRY_INTERVAL = GETDATA_TX_INTERVAL * 10

# Python test constants
NUM_INBOUND = 10
MAX_GETDATA_INBOUND_WAIT = GETDATA_TX_INTERVAL + MAX_GETDATA_RANDOM_DELAY + INBOUND_PEER_TX_DELAY
MAX_GETDATA_INBOUND_WAIT = GETDATA_TX_INTERVAL + MAX_GETDATA_RANDOM_DELAY + INBOUND_PEER_TX_DELAY + TXID_RELAY_DELAY


class TxDownloadTest(BitcoinTestFramework):
Expand All @@ -63,7 +65,7 @@ def test_tx_requests(self):
txid = 0xdeadbeef

self.log.info("Announce the txid from each incoming peer to node 0")
msg = msg_inv([CInv(t=MSG_TX, h=txid)])
msg = msg_inv([CInv(t=MSG_WTX, h=txid)])
for p in self.nodes[0].p2ps:
p.send_and_ping(msg)

Expand Down Expand Up @@ -135,13 +137,13 @@ def test_in_flight_max(self):
with mininode_lock:
p.tx_getdata_count = 0

p.send_message(msg_inv([CInv(t=MSG_TX, h=i) for i in txids]))
p.send_message(msg_inv([CInv(t=MSG_WTX, h=i) for i in txids]))
wait_until(lambda: p.tx_getdata_count >= MAX_GETDATA_IN_FLIGHT, lock=mininode_lock)
with mininode_lock:
assert_equal(p.tx_getdata_count, MAX_GETDATA_IN_FLIGHT)

self.log.info("Now check that if we send a NOTFOUND for a transaction, we'll get one more request")
p.send_message(msg_notfound(vec=[CInv(t=MSG_TX, h=txids[0])]))
p.send_message(msg_notfound(vec=[CInv(t=MSG_WTX, h=txids[0])]))
wait_until(lambda: p.tx_getdata_count >= MAX_GETDATA_IN_FLIGHT + 1, timeout=10, lock=mininode_lock)
with mininode_lock:
assert_equal(p.tx_getdata_count, MAX_GETDATA_IN_FLIGHT + 1)
Expand Down
25 changes: 23 additions & 2 deletions test/functional/test_framework/messages.py
Expand Up @@ -31,7 +31,7 @@
from test_framework.util import hex_str_to_bytes, assert_equal

MIN_VERSION_SUPPORTED = 60001
MY_VERSION = 70014 # past bip-31 for ping/pong
MY_VERSION = 70016 # past wtxid relay
MY_SUBVERSION = b"/python-mininode-tester:0.0.3/"
MY_RELAY = 1 # from version 70001 onwards, fRelay should be appended to version messages (BIP37)

Expand Down Expand Up @@ -59,6 +59,7 @@
MSG_BLOCK = 2
MSG_FILTERED_BLOCK = 3
MSG_CMPCT_BLOCK = 4
MSG_WTX = 5
MSG_WITNESS_FLAG = 1 << 30
MSG_TYPE_MASK = 0xffffffff >> 2

Expand Down Expand Up @@ -242,7 +243,8 @@ class CInv:
MSG_TX | MSG_WITNESS_FLAG: "WitnessTx",
MSG_BLOCK | MSG_WITNESS_FLAG: "WitnessBlock",
MSG_FILTERED_BLOCK: "filtered Block",
4: "CompactBlock"
4: "CompactBlock",
5: "WTX",
}

def __init__(self, t=0, h=0):
Expand All @@ -263,6 +265,9 @@ def __repr__(self):
return "CInv(type=%s hash=%064x)" \
% (self.typemap[self.type], self.hash)

def __eq__(self, other):
return isinstance(other, CInv) and self.hash == other.hash and self.type == other.type


class CBlockLocator:
__slots__ = ("nVersion", "vHave")
Expand Down Expand Up @@ -1124,6 +1129,22 @@ def serialize(self):
def __repr__(self):
return "msg_tx(tx=%s)" % (repr(self.tx))

class msg_wtxidrelay:
__slots__ = ()
msgtype = b"wtxidrelay"

def __init__(self):
pass

def deserialize(self, f):
pass

def serialize(self):
return b""

def __repr__(self):
return "msg_wtxidrelay()"


class msg_no_witness_tx(msg_tx):
__slots__ = ()
Expand Down
8 changes: 7 additions & 1 deletion test/functional/test_framework/mininode.py
Expand Up @@ -59,6 +59,8 @@
MSG_TYPE_MASK,
msg_verack,
msg_version,
MSG_WTX,
msg_wtxidrelay,
NODE_NETWORK,
NODE_WITNESS,
sha256,
Expand Down Expand Up @@ -96,6 +98,7 @@
b"tx": msg_tx,
b"verack": msg_verack,
b"version": msg_version,
b"wtxidrelay": msg_wtxidrelay,
}

MAGIC_BYTES = {
Expand Down Expand Up @@ -356,6 +359,7 @@ def on_pong(self, message): pass
def on_sendcmpct(self, message): pass
def on_sendheaders(self, message): pass
def on_tx(self, message): pass
def on_wtxidrelay(self, message): pass

def on_inv(self, message):
want = msg_getdata()
Expand All @@ -373,6 +377,8 @@ def on_verack(self, message):

def on_version(self, message):
assert message.nVersion >= MIN_VERSION_SUPPORTED, "Version {} received. Test framework only supports versions greater than {}".format(message.nVersion, MIN_VERSION_SUPPORTED)
if message.nVersion >= 70016:
self.send_message(msg_wtxidrelay())
self.send_message(msg_verack())
self.nServices = message.nServices

Expand Down Expand Up @@ -654,7 +660,7 @@ def on_inv(self, message):
super().on_inv(message) # Send getdata in response.
# Store how many times invs have been received for each tx.
for i in message.inv:
if i.type == MSG_TX:
if (i.type == MSG_TX) or (i.type == MSG_WTX):
# save txid
self.tx_invs_received[i.hash] += 1

Expand Down