Skip to content

Commit

Permalink
Begin work on block streaming.
Browse files Browse the repository at this point in the history
  • Loading branch information
Neil Booth committed Feb 10, 2021
1 parent 75e4738 commit 2d23dd4
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 15 deletions.
151 changes: 139 additions & 12 deletions electrumx/lib/tx.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@
'''Transaction-related classes and functions.'''

from collections import namedtuple
from io import BytesIO
from struct import error as struct_error

from electrumx.lib.hash import double_sha256, hash_to_hex_str
from electrumx.lib.hash import double_sha256, hash_to_hex_str, _sha256
from electrumx.lib.util import (
unpack_le_int32_from, unpack_le_int64_from, unpack_le_uint16_from,
unpack_be_uint16_from,
unpack_le_uint32_from, unpack_le_uint64_from, pack_le_int32, pack_varint,
pack_le_uint32, pack_le_int64, pack_varbytes,
pack_le_uint32, pack_le_int64, pack_varbytes, unpack_byte,
unpack_le_uint16, unpack_le_uint32, unpack_le_uint64, unpack_le_int32, unpack_le_int64
)

ZERO = bytes(32)
Expand Down Expand Up @@ -130,8 +132,8 @@ def read_tx_block(self):
return [read() for _ in range(self._read_varint())]

def _read_inputs(self):
read_input = self._read_input
return [read_input() for i in range(self._read_varint())]
read_input_ = self._read_input
return [read_input_() for i in range(self._read_varint())]

def _read_input(self):
return TxInput(
Expand All @@ -142,8 +144,8 @@ def _read_input(self):
)

def _read_outputs(self):
read_output = self._read_output
return [read_output() for i in range(self._read_varint())]
read_output_ = self._read_output
return [read_output_() for i in range(self._read_varint())]

def _read_output(self):
return TxOutput(
Expand Down Expand Up @@ -191,11 +193,6 @@ def _read_le_uint16(self):
self.cursor += 2
return result

def _read_be_uint16(self):
result, = unpack_be_uint16_from(self.binary, self.cursor)
self.cursor += 2
return result

def _read_le_uint32(self):
result, = unpack_le_uint32_from(self.binary, self.cursor)
self.cursor += 4
Expand All @@ -205,3 +202,133 @@ def _read_le_uint64(self):
result, = unpack_le_uint64_from(self.binary, self.cursor)
self.cursor += 8
return result


class TxStream:

def __init__(self, fetch_next):
self.fetch_next = fetch_next
self.buf = b''
self.start = 0
self.cursor = 0
self._read = BytesIO(self.buf).read
self._hash = None

async def read(self, n):
data = self._read(n)
dlen = len(data)
self.cursor += dlen
if dlen == n:
return data

parts = [data]
n -= dlen
while n:
self.update_hash()
self.start = 0
self.buf = await self.fetch_next()
self._read = BytesIO(self.buf).read
data = self._read(n)
dlen = len(data)
parts.append(data)
n -= dlen
self.cursor += dlen

return b''.join(parts)

def update_hash(self):
if self._hash is None:
self._hash = _sha256()
start = self.start
self.start = self.cursor
self._hash.update(self.buf[start: self.cursor])

def get_hash(self):
self.update_hash()
result = self._hash.digest()
self._hash = None
return result

async def read_tx(self):
tx = await read_tx(self.read)
return tx, self.get_hash()


# Stream operations

async def read_le_int32(read):
result, = unpack_le_int32(await read(4))
return result


async def read_le_int64(read):
result, = unpack_le_int64(await read(8))
return result


async def read_le_uint16(read):
result, = unpack_le_uint16(await read(2))
return result


async def read_le_uint32(read):
result, = unpack_le_uint32(await read(4))
return result


async def read_le_uint64(read):
result, = unpack_le_uint64(await read(8))
return result


async def read_varint(read):
# read_byte is supported by mmap objects but not BytesIO
n, = unpack_byte(await read(1))
if n < 253:
return n
if n == 253:
return await read_le_uint16(read)
if n == 254:
return await read_le_uint32(read)
return await read_le_uint64(read)


async def read_varbytes(read):
n = await read_varint(read)
result = await read(n)
if len(result) != n:
raise struct_error(f'varbytes requires a buffer of {n:,d} bytes')
return result


async def read_list(read, read_one):
'''Return a list of items.
Each item is read with read_one, the stream begins with a count of the items.'''
return [await read_one(read) for _ in range(await read_varint(read))]


async def read_tx(read):
'''Return a deserialized transaction.'''
return Tx(
await read_le_int32(read), # version
await read_list(read, read_input), # inputs
await read_list(read, read_output), # outputs
await read_le_uint32(read) # locktime
)


async def read_input(read):
return TxInput(
await read(32), # prev_hash
await read_le_uint32(read), # prev_idx
await read_varbytes(read), # script
await read_le_uint32(read) # sequence
)


async def read_output(read):
return TxOutput(
await read_le_int64(read), # value
await read_varbytes(read), # pk_script
)
6 changes: 5 additions & 1 deletion electrumx/lib/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,9 +317,13 @@ def protocol_version(client_req, min_tuple, max_tuple):
unpack_be_uint16_from = struct_be_H.unpack_from
unpack_be_uint32_from = struct_be_I.unpack_from

unpack_be_uint32 = struct_be_I.unpack
unpack_le_int32 = struct_le_i.unpack
unpack_le_int64 = struct_le_q.unpack
unpack_le_uint16 = struct_le_H.unpack
unpack_le_uint32 = struct_le_I.unpack
unpack_le_uint64 = struct_le_Q.unpack
unpack_be_uint32 = struct_be_I.unpack
unpack_byte = structB.unpack

pack_le_int32 = struct_le_i.pack
pack_le_int64 = struct_le_q.pack
Expand Down
54 changes: 52 additions & 2 deletions tests/lib/test_tx.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import electrumx.lib.tx as tx_lib
import pytest

from electrumx.lib.tx import Deserializer, TxStream
from electrumx.lib.hash import sha256
import random


tests = [
"020000000192809f0b234cb850d71d020e678e93f074648ed0df5affd0c46d3bcb177f"
Expand Down Expand Up @@ -32,6 +37,51 @@
def test_tx_serialiazation():
for test in tests:
test = bytes.fromhex(test)
deser = tx_lib.Deserializer(test)
deser = Deserializer(test)
tx = deser.read_tx()
assert tx.serialize() == test


class EndOfStream(Exception):
pass


class StreamedData:

def __init__(self, data):
self.data = data
self.cursor = 0

async def fetch_next(self):
remaining = len(self.data) - self.cursor
if remaining == 0:
raise EndOfStream
size = random.randrange(0, remaining) + 1
cursor = self.cursor
self.cursor += size
return self.data[cursor: self.cursor]


class TestTxStream:

@pytest.mark.asyncio
async def test_simple(self):
data = bytes(range(64))
sdata = StreamedData(data)
stream = TxStream(sdata.fetch_next)
expected_hash = sha256(data)
stream_data = await stream.read(len(data))
stream_hash = stream.get_hash()
assert stream_data.hex() == data.hex()
assert stream_hash.hex() == expected_hash.hex()

@pytest.mark.asyncio
@pytest.mark.parametrize("raw_tx_hex", tests)
async def test_read_tx(self, raw_tx_hex):
raw_tx = bytes.fromhex(raw_tx_hex)
sdata = StreamedData(raw_tx)
stream = TxStream(sdata.fetch_next)
expected_tx_hash = sha256(raw_tx)
tx, tx_hash = await stream.read_tx()
assert tx.serialize().hex() == raw_tx_hex
assert tx_hash == expected_tx_hash

0 comments on commit 2d23dd4

Please sign in to comment.