Skip to content

Commit

Permalink
Improve generation inputs handling
Browse files Browse the repository at this point in the history
  • Loading branch information
erasmospunk committed Aug 16, 2018
1 parent 8ab0465 commit a6d46fa
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 60 deletions.
54 changes: 11 additions & 43 deletions electrumx/lib/tx.py
Expand Up @@ -31,20 +31,23 @@

from electrumx.lib.hash import sha256, double_sha256, hash_to_hex_str
from electrumx.lib.util import (
cachedproperty, unpack_le_int32_from, unpack_le_int64_from,
unpack_le_uint16_from, unpack_le_uint32_from, unpack_le_uint64_from,
pack_le_int32, pack_varint, pack_le_uint32, pack_le_uint32, pack_le_int64,
pack_varbytes,
unpack_le_int32_from, unpack_le_int64_from, unpack_le_uint16_from,
unpack_le_uint32_from, unpack_le_uint64_from, pack_le_int32, pack_varint,
pack_le_uint32, pack_le_int64, pack_varbytes,
)

ZERO = bytes(32)
MINUS_1 = 4294967295


def is_gen_outpoint(hash, index):
'''Test if an outpoint is a generation/coinbase like'''
return index == MINUS_1 and hash == ZERO


class Tx(namedtuple("Tx", "version inputs outputs locktime")):
'''Class representing a transaction.'''

@cachedproperty
def is_generation(self):
return self.inputs[0].is_generation

def serialize(self):
return b''.join((
pack_le_int32(self.version),
Expand All @@ -58,15 +61,6 @@ def serialize(self):

class TxInput(namedtuple("TxInput", "prev_hash prev_idx script sequence")):
'''Class representing a transaction input.'''

ZERO = bytes(32)
MINUS_1 = 4294967295

@cachedproperty
def is_generation(self):
return (self.prev_idx == TxInput.MINUS_1 and
self.prev_hash == TxInput.ZERO)

def __str__(self):
script = self.script.hex()
prev_hash = hash_to_hex_str(self.prev_hash)
Expand Down Expand Up @@ -214,10 +208,6 @@ class TxSegWit(namedtuple("Tx", "version marker flag inputs outputs "
"witness locktime")):
'''Class representing a SegWit transaction.'''

@cachedproperty
def is_generation(self):
return self.inputs[0].is_generation


class DeserializerSegWit(Deserializer):

Expand Down Expand Up @@ -326,10 +316,6 @@ class DeserializerEquihashSegWit(DeserializerSegWit, DeserializerEquihash):
class TxJoinSplit(namedtuple("Tx", "version inputs outputs locktime")):
'''Class representing a JoinSplit transaction.'''

@cachedproperty
def is_generation(self):
return self.inputs[0].is_generation if len(self.inputs) > 0 else False


class DeserializerZcash(DeserializerEquihash):
def read_tx(self):
Expand Down Expand Up @@ -360,10 +346,6 @@ def read_tx(self):
class TxTime(namedtuple("Tx", "version time inputs outputs locktime")):
'''Class representing transaction that has a time field.'''

@cachedproperty
def is_generation(self):
return self.inputs[0].is_generation


class DeserializerTxTime(Deserializer):
def read_tx(self):
Expand Down Expand Up @@ -445,11 +427,6 @@ class DeserializerGroestlcoin(DeserializerSegWit):
class TxInputDcr(namedtuple("TxInput", "prev_hash prev_idx tree sequence")):
'''Class representing a Decred transaction input.'''

@cachedproperty
def is_generation(self):
return (self.prev_idx == TxInput.MINUS_1 and
self.prev_hash == TxInput.ZERO)

def __str__(self):
prev_hash = hash_to_hex_str(self.prev_hash)
return ("Input({}, {:d}, tree={}, sequence={:d})"
Expand All @@ -465,10 +442,6 @@ class TxDcr(namedtuple("Tx", "version inputs outputs locktime expiry "
"witness")):
'''Class representing a Decred transaction.'''

@cachedproperty
def is_generation(self):
return self.inputs[0].is_generation


class DeserializerDecred(Deserializer):
@staticmethod
Expand Down Expand Up @@ -541,11 +514,6 @@ def _read_tx_parts(self, produce_hash=True):
end_prefix = self.cursor
witness = self._read_witness(len(inputs))

# Drop the coinbase-like input from a vote tx as it creates problems
# with UTXOs lookups and mempool management
if inputs[0].is_generation and len(inputs) > 1:
inputs = inputs[1:]

if produce_hash:
# TxSerializeNoWitness << 16 == 0x10000
no_witness_header = pack_le_uint32(0x10000 | (version & 0xffff))
Expand Down
27 changes: 15 additions & 12 deletions electrumx/server/block_processor.py
Expand Up @@ -18,6 +18,7 @@
from aiorpcx import TaskGroup, run_in_thread

import electrumx
from electrumx.lib.tx import is_gen_outpoint
from electrumx.server.daemon import DaemonError
from electrumx.lib.hash import hash_to_hex_str, HASHX_LEN
from electrumx.lib.util import chunks, class_logger
Expand Down Expand Up @@ -411,11 +412,12 @@ def advance_txs(self, txs):
tx_numb = s_pack('<I', tx_num)

# Spend the inputs
if not tx.is_generation:
for txin in tx.inputs:
cache_value = spend_utxo(txin.prev_hash, txin.prev_idx)
undo_info_append(cache_value)
append_hashX(cache_value[:-12])
for txin in tx.inputs:
if is_gen_outpoint(txin.prev_hash, txin.prev_idx):
continue
cache_value = spend_utxo(txin.prev_hash, txin.prev_idx)
undo_info_append(cache_value)
append_hashX(cache_value[:-12])

# Add the new UTXOs
for idx, txout in enumerate(tx.outputs):
Expand Down Expand Up @@ -490,13 +492,14 @@ def backup_txs(self, txs):
touched.add(cache_value[:-12])

# Restore the inputs
if not tx.is_generation:
for txin in reversed(tx.inputs):
n -= undo_entry_len
undo_item = undo_info[n:n + undo_entry_len]
put_utxo(txin.prev_hash + s_pack('<H', txin.prev_idx),
undo_item)
touched.add(undo_item[:-12])
for txin in reversed(tx.inputs):
if is_gen_outpoint(txin.prev_hash, txin.prev_idx):
continue
n -= undo_entry_len
undo_item = undo_info[n:n + undo_entry_len]
put_utxo(txin.prev_hash + s_pack('<H', txin.prev_idx),
undo_item)
touched.add(undo_item[:-12])

assert n == 0
self.tx_count -= len(txs)
Expand Down
16 changes: 12 additions & 4 deletions electrumx/server/mempool.py
Expand Up @@ -17,6 +17,7 @@
from aiorpcx import TaskGroup, run_in_thread, sleep

from electrumx.lib.hash import hash_to_hex_str, hex_str_to_hash
from electrumx.lib.tx import is_gen_outpoint
from electrumx.lib.util import class_logger, chunks
from electrumx.server.db import UTXO

Expand Down Expand Up @@ -172,6 +173,9 @@ def _accept_transactions(self, tx_map, utxo_map, touched):
in_pairs = []
try:
for prevout in tx.prevouts:
# Skip generation like prevouts
if is_gen_outpoint(*prevout):
continue
utxo = utxo_map.get(prevout)
if not utxo:
prev_hash, prev_index = prevout
Expand All @@ -187,8 +191,10 @@ def _accept_transactions(self, tx_map, utxo_map, touched):

# Save the in_pairs, compute the fee and accept the TX
tx.in_pairs = tuple(in_pairs)
tx.fee = (sum(v for hashX, v in tx.in_pairs) -
sum(v for hashX, v in tx.out_pairs))
# Avoid negative fees if dealing with generation-like transactions
# because some in_parts would be missing
tx.fee = max(0, (sum(v for _, v in tx.in_pairs) -
sum(v for _, v in tx.out_pairs)))
txs[hash] = tx

for hashX, value in itertools.chain(tx.in_pairs, tx.out_pairs):
Expand Down Expand Up @@ -285,10 +291,12 @@ def deserialize_txs(): # This function is pure
# Determine all prevouts not in the mempool, and fetch the
# UTXO information from the database. Failed prevout lookups
# return None - concurrent database updates happen - which is
# relied upon by _accept_transactions
# relied upon by _accept_transactions. Ignore prevouts that are
# generation-like.
prevouts = tuple(prevout for tx in tx_map.values()
for prevout in tx.prevouts
if prevout[0] not in all_hashes)
if (prevout[0] not in all_hashes and
not is_gen_outpoint(*prevout)))
utxos = await self.api.lookup_utxos(prevouts)
utxo_map = {prevout: utxo for prevout, utxo in zip(prevouts, utxos)}

Expand Down
15 changes: 14 additions & 1 deletion tests/server/test_mempool.py
Expand Up @@ -10,7 +10,7 @@
from electrumx.server.mempool import MemPool, MemPoolAPI
from electrumx.lib.coins import BitcoinCash
from electrumx.lib.hash import HASHX_LEN, hex_str_to_hash, hash_to_hex_str
from electrumx.lib.tx import Tx, TxInput, TxOutput
from electrumx.lib.tx import Tx, TxInput, TxOutput, is_gen_outpoint
from electrumx.lib.util import make_logger


Expand All @@ -32,6 +32,9 @@ def random_tx(hash160s, utxos):
inputs.append(TxInput(prevout[0], prevout[1], b'', 4294967295))
input_value += value

# Add a generation/coinbase like input that is present in some coins
inputs.append(TxInput(bytes(32), 4294967295, b'', 4294967295))

fee = min(input_value, randrange(500))
input_value -= fee
outputs = []
Expand Down Expand Up @@ -105,6 +108,8 @@ def balance_deltas(self):
for tx_hash, tx in self.txs.items():
for n, input in enumerate(tx.inputs):
prevout = (input.prev_hash, input.prev_idx)
if is_gen_outpoint(input.prev_hash, input.prev_idx):
continue
if prevout in utxos:
utxos.pop(prevout)
else:
Expand All @@ -121,6 +126,8 @@ def spends(self):
for tx_hash, tx in self.txs.items():
for n, input in enumerate(tx.inputs):
prevout = (input.prev_hash, input.prev_idx)
if is_gen_outpoint(input.prev_hash, input.prev_idx):
continue
if prevout in utxos:
hashX, value = utxos.pop(prevout)
else:
Expand All @@ -137,6 +144,8 @@ def summaries(self):
hashXs = set()
has_ui = False
for n, input in enumerate(tx.inputs):
if is_gen_outpoint(input.prev_hash, input.prev_idx):
continue
has_ui = has_ui or (input.prev_hash in self.txs)
prevout = (input.prev_hash, input.prev_idx)
if prevout in utxos:
Expand All @@ -161,6 +170,8 @@ def touched(self, tx_hashes):
for tx_hash in tx_hashes:
tx = self.txs[tx_hash]
for n, input in enumerate(tx.inputs):
if is_gen_outpoint(input.prev_hash, input.prev_idx):
continue
prevout = (input.prev_hash, input.prev_idx)
if prevout in utxos:
hashX, value = utxos[prevout]
Expand Down Expand Up @@ -471,6 +482,8 @@ async def test_notifications():
api._height = new_height
api.db_utxos.update(first_utxos)
for spend in first_spends:
if is_gen_outpoint(*spend):
continue
del api.db_utxos[spend]
api.raw_txs = {hash: raw_txs[hash] for hash in second_hashes}
api.txs = {hash: txs[hash] for hash in second_hashes}
Expand Down

0 comments on commit a6d46fa

Please sign in to comment.