Skip to content

Commit

Permalink
cleaned up the deserialize method and frozen some dataclass
Browse files Browse the repository at this point in the history
  • Loading branch information
fametrano committed Dec 13, 2020
1 parent 713a6ac commit 9e123a7
Show file tree
Hide file tree
Showing 11 changed files with 68 additions and 96 deletions.
25 changes: 9 additions & 16 deletions btclib/bms.py
Expand Up @@ -156,7 +156,7 @@
_Sig = TypeVar("_Sig", bound="Sig")


@dataclass
@dataclass(frozen=True)
class Sig(DataClassJsonMixin):
# 1 byte
rf: int = 0
Expand Down Expand Up @@ -192,21 +192,14 @@ def deserialize(
) -> _Sig:

stream = bytesio_from_binarydata(data)
sig = cls(check_validity=False)

sig.rf = int.from_bytes(stream.read(1), byteorder="big", signed=False)

nsize = sig.dsa_sig.ec.nsize
sig.dsa_sig.r = int.from_bytes(
stream.read(nsize), byteorder="big", signed=False
)
sig.dsa_sig.s = int.from_bytes(
stream.read(nsize), byteorder="big", signed=False
)

if assert_valid:
sig.assert_valid()
return sig
rf = int.from_bytes(stream.read(1), byteorder="big", signed=False)
ec = secp256k1
nsize = ec.nsize
r = int.from_bytes(stream.read(nsize), byteorder="big", signed=False)
s = int.from_bytes(stream.read(nsize), byteorder="big", signed=False)
dsa_sig = dsa.Sig(r, s, ec, check_validity=False)

return cls(rf, dsa_sig, check_validity=assert_valid)

def b64encode(self, assert_valid: bool = True) -> bytes:
"""Return the BMS address-based signature as base64-encoding.
Expand Down
12 changes: 5 additions & 7 deletions btclib/der.py
Expand Up @@ -95,7 +95,7 @@ def _deserialize_scalar(sig_data_stream: BytesIO) -> int:
_Sig = TypeVar("_Sig", bound="Sig")


@dataclass
@dataclass(frozen=True)
class Sig(DataClassJsonMixin):
# 32 bytes
r: int = field(
Expand Down Expand Up @@ -148,7 +148,7 @@ def deserialize(
"""

stream = bytesio_from_binarydata(data)
sig = cls(check_validity=False)
ec = secp256k1

# [0x30] [data-size][0x02][r-size][r][0x02][s-size][s]
marker = stream.read(1)
Expand All @@ -162,15 +162,13 @@ def deserialize(

# [0x02][r-size][r][0x02][s-size][s]
sig_data_substream = bytesio_from_binarydata(sig_data)
sig.r = _deserialize_scalar(sig_data_substream)
sig.s = _deserialize_scalar(sig_data_substream)
r = _deserialize_scalar(sig_data_substream)
s = _deserialize_scalar(sig_data_substream)

# to prevent malleability
# the sig_data_substream must have been consumed entirely
if sig_data_substream.read(1) != b"":
err_msg = "invalid DER sequence length"
raise BTClibValueError(err_msg)

if assert_valid:
sig.assert_valid()
return sig
return cls(r, s, ec, check_validity=assert_valid)
14 changes: 5 additions & 9 deletions btclib/ssa.py
Expand Up @@ -75,7 +75,7 @@
_Sig = TypeVar("_Sig", bound="Sig")


@dataclass
@dataclass(frozen=True)
class Sig(DataClassJsonMixin):
"""BIP340-Schnorr signature.
Expand Down Expand Up @@ -127,14 +127,10 @@ def deserialize(
) -> _Sig:

stream = bytesio_from_binarydata(data)
sig = cls(check_validity=False)

sig.r = int.from_bytes(stream.read(sig.ec.psize), byteorder="big", signed=False)
sig.s = int.from_bytes(stream.read(sig.ec.nsize), byteorder="big", signed=False)

if assert_valid:
sig.assert_valid()
return sig
ec = secp256k1
r = int.from_bytes(stream.read(ec.psize), byteorder="big", signed=False)
s = int.from_bytes(stream.read(ec.nsize), byteorder="big", signed=False)
return cls(r, s, ec, check_validity=assert_valid)


# hex-string or bytes representation of an int
Expand Down
39 changes: 22 additions & 17 deletions btclib/tests/test_bms.py
Expand Up @@ -52,10 +52,10 @@ def test_signature() -> None:
bms.assert_as_valid(msg, addr, exp_sig)
bms.assert_as_valid(msg, addr, exp_sig.decode())

sig.dsa_sig.ec = CURVES["secp256r1"]
dsa_sig = dsa.Sig(sig.dsa_sig.r, sig.dsa_sig.s, CURVES["secp256r1"])
err_msg = "invalid curve: "
with pytest.raises(BTClibValueError, match=err_msg):
bms.assert_as_valid(msg, addr, sig)
sig = bms.Sig(sig.rf, dsa_sig)


def test_exceptions() -> None:
Expand All @@ -67,10 +67,9 @@ def test_exceptions() -> None:
bms.assert_as_valid(msg, address, exp_sig)

bms_sig = bms.Sig.b64decode(exp_sig)
bms_sig.rf = 26
err_msg = "invalid recovery flag: "
with pytest.raises(BTClibValueError, match=err_msg):
bms_sig.serialize()
bms.Sig(26, bms_sig.dsa_sig)

exp_sig = "IHdKsFF1bUrapA8GMoQUbgI+Ad0ZXyX1c/yAZHmJn5hNBi7J+TrI1615FG3g9JEOPGVvcfDWIFWrg2exLoVc="
err_msg = "invalid decoded length: "
Expand Down Expand Up @@ -121,7 +120,7 @@ def test_exceptions() -> None:
# Invalid recovery flag (39) for base58 address
exp_sig = "IHdKsFF1bUrapA8GMoQUbgI+Ad0ZXyX1c/yAZHmJn5hSNBi7J+TrI1615FG3g9JEOPGVvcfDWIFWrg2exLNtoVc="
bms_sig = bms.Sig.b64decode(exp_sig)
bms_sig.rf = 39
bms_sig = bms.Sig(39, bms_sig.dsa_sig, check_validity=False)
sig_encoded = bms_sig.b64encode(assert_valid=False)
err_msg = "invalid recovery flag: "
with pytest.raises(BTClibValueError, match=err_msg):
Expand All @@ -130,7 +129,7 @@ def test_exceptions() -> None:
# Invalid recovery flag (35) for bech32 address
exp_sig = "IBFyn+h9m3pWYbB4fBFKlRzBD4eJKojgCIZSNdhLKKHPSV2/WkeV7R7IOI0dpo3uGAEpCz9eepXLrA5kF35MXuU="
bms_sig = bms.Sig.b64decode(exp_sig)
bms_sig.rf = 35
bms_sig = bms.Sig(35, bms_sig.dsa_sig, check_validity=False)
err_msg = "invalid recovery flag: "
with pytest.raises(BTClibValueError, match=err_msg):
bms.assert_as_valid(msg, b58_p2wpkh, bms_sig)
Expand Down Expand Up @@ -290,15 +289,17 @@ def test_msgsign_p2pkh() -> None:
assert not bms.verify(msg, add1c, sig1u)
assert not bms.verify(msg, add1u, sig1c)

sig1c.rf += 1 # change rf
assert not bms.verify(msg, add1c, sig1c)
sig1c.rf -= 1 # restore rf
bms_sig = bms.Sig(sig1c.rf + 1, sig1c.dsa_sig)
assert not bms.verify(msg, add1c, bms_sig)

sig1c.dsa_sig.s = ec.n - sig1c.dsa_sig.s # malleate s
assert not bms.verify(msg, add1c, sig1c)
# malleate s
dsa_sig = dsa.Sig(sig1c.dsa_sig.r, ec.n - sig1c.dsa_sig.s, sig1c.dsa_sig.ec)
bms_sig = bms.Sig(sig1c.rf, dsa_sig)
assert not bms.verify(msg, add1c, bms_sig)

sig1c.rf += 1 # update rf to satisfy above malleation
assert bms.verify(msg, add1c, sig1c)
# update rf to satisfy above malleation
bms_sig = bms.Sig(sig1c.rf + 1, dsa_sig)
assert bms.verify(msg, add1c, bms_sig)


def test_msgsign_p2pkh_2() -> None:
Expand Down Expand Up @@ -530,10 +531,14 @@ def test_vector_python_bitcoinlib() -> None:
# self.assertGreater(test_vector_sig.dsa_sig.s, ec.n - test_vector_sig.dsa_sig.s)

# just in case you wonder, here's the malleated signature
bsm_sig.rf += 1 if bsm_sig.rf == 31 else -1
bsm_sig.dsa_sig.s = ec.n - bsm_sig.dsa_sig.s
assert bms.verify(msg, vector["address"], bsm_sig)
bsm_sig_encoded = bsm_sig.b64encode()
dsa_sig = dsa.Sig(
bsm_sig.dsa_sig.r, ec.n - bsm_sig.dsa_sig.s, bsm_sig.dsa_sig.ec
)
bms_sig_malleated = bms.Sig(
bsm_sig.rf + (1 if bsm_sig.rf == 31 else -1), dsa_sig
)
assert bms.verify(msg, vector["address"], bms_sig_malleated)
bsm_sig_encoded = bms_sig_malleated.b64encode()
assert bms.verify(msg, vector["address"], bsm_sig_encoded)
# of course,
# it is not equal to the python-bitcoinlib one (different r)
Expand Down
2 changes: 1 addition & 1 deletion btclib/tests/test_dataclasses_json_bug.py
Expand Up @@ -16,7 +16,7 @@
from dataclasses_json import DataClassJsonMixin


def test_dataclasses_json_bu() -> None:
def test_dataclasses_json_bug() -> None:
@dataclass
class Person(DataClassJsonMixin):
name: str
Expand Down
4 changes: 2 additions & 2 deletions btclib/tests/test_dsa.py
Expand Up @@ -236,9 +236,9 @@ def test_crack_prv_key() -> None:
with pytest.raises(BTClibValueError, match="identical signatures"):
dsa.crack_prv_key(msg1, sig1, msg1, sig1)

sig1.ec = CURVES["secp256r1"]
sig = dsa.Sig(sig1.r, sig1.s, CURVES["secp256r1"])
with pytest.raises(BTClibValueError, match="not the same curve in signatures"):
dsa.crack_prv_key(msg1, sig1, msg2, sig2)
dsa.crack_prv_key(msg1, sig, msg2, sig2)


def test_forge_hash_sig() -> None:
Expand Down
12 changes: 6 additions & 6 deletions btclib/tests/test_ssa.py
Expand Up @@ -247,16 +247,16 @@ def test_crack_prv_key() -> None:
assert q == qc
assert k in (kc, sig1.ec.n - kc)

sig2.r = 16
sig = ssa.Sig(16, sig2.s, sig2.ec)
with pytest.raises(BTClibValueError, match="not the same r in signatures"):
ssa._crack_prv_key(m_1, sig1, m_2, sig2, x_Q)
ssa._crack_prv_key(m_1, sig1, m_2, sig, x_Q)

with pytest.raises(BTClibValueError, match="identical signatures"):
ssa._crack_prv_key(m_1, sig1, m_1, sig1, x_Q)

sig1.ec = CURVES["secp256r1"]
sig = ssa.Sig(sig1.r, sig1.s, CURVES["secp256r1"])
with pytest.raises(BTClibValueError, match="not the same curve in signatures"):
ssa._crack_prv_key(m_1, sig1, m_2, sig2, x_Q)
ssa._crack_prv_key(m_1, sig, m_2, sig2, x_Q)


def test_batch_validation() -> None:
Expand Down Expand Up @@ -310,12 +310,12 @@ def test_batch_validation() -> None:
assert not ssa.batch_verify(ms, Qs, sigs)
sigs.pop() # valid again

sigs[0].ec = CURVES["secp256r1"]
sigs[0] = ssa.Sig(sigs[0].r, sigs[0].s, CURVES["secp256r1"]) # different curve
err_msg = "not the same curve for all signatures"
with pytest.raises(BTClibValueError, match=err_msg):
ssa.assert_batch_as_valid(ms, Qs, sigs)
assert not ssa.batch_verify(ms, Qs, sigs)
sigs[0].ec = CURVES["secp256k1"] # valid again
sigs[0] = ssa.Sig(sigs[0].r, sigs[0].s, CURVES["secp256k1"]) # same curve again

ms = [reduce_to_hlen(m, hf) for m in ms]
ms[0] = ms[0][:-1]
Expand Down
15 changes: 6 additions & 9 deletions btclib/tx.py
Expand Up @@ -222,9 +222,8 @@ def deserialize(cls: Type[_Tx], data: BinaryData, assert_valid: bool = True) ->
"Return a Tx by parsing binary data."

stream = bytesio_from_binarydata(data)
tx = cls(check_validity=False)

tx.version = int.from_bytes(stream.read(4), byteorder="little", signed=False)
version = int.from_bytes(stream.read(4), byteorder="little", signed=False)

segwit = stream.read(2) == _SEGWIT_MARKER
if not segwit:
Expand All @@ -234,17 +233,15 @@ def deserialize(cls: Type[_Tx], data: BinaryData, assert_valid: bool = True) ->
stream.seek(-2, whence)

n = var_int.deserialize(stream)
tx.vin = [TxIn.deserialize(stream) for _ in range(n)]
vin = [TxIn.deserialize(stream) for _ in range(n)]

n = var_int.deserialize(stream)
tx.vout = [TxOut.deserialize(stream) for _ in range(n)]
vout = [TxOut.deserialize(stream) for _ in range(n)]

if segwit:
for tx_in in tx.vin:
for tx_in in vin:
tx_in.script_witness = Witness.deserialize(stream, assert_valid)

tx.lock_time = int.from_bytes(stream.read(4), byteorder="little", signed=False)
lock_time = int.from_bytes(stream.read(4), byteorder="little", signed=False)

if assert_valid:
tx.assert_valid()
return tx
return cls(version, lock_time, vin, vout, check_validity=assert_valid)
22 changes: 7 additions & 15 deletions btclib/tx_in.py
Expand Up @@ -80,14 +80,10 @@ def deserialize(
"Return an OutPoint from the first 36 bytes of the provided data."

data = bytesio_from_binarydata(data)
tx_id = data.read(32)[::-1]
vout = int.from_bytes(data.read(4), "little", signed=False)

outpoint = cls(check_validity=False)
outpoint.tx_id = data.read(32)[::-1]
outpoint.vout = int.from_bytes(data.read(4), "little", signed=False)

if assert_valid:
outpoint.assert_valid()
return outpoint
return cls(tx_id, vout, check_validity=assert_valid)


_TxIn = TypeVar("_TxIn", bound="TxIn")
Expand Down Expand Up @@ -180,12 +176,8 @@ def deserialize(
) -> _TxIn:

s = bytesio_from_binarydata(data)
prev_out = OutPoint.deserialize(s)
script_sig = var_bytes.deserialize(s)
sequence = int.from_bytes(s.read(4), byteorder="little", signed=False)

tx_in = cls(check_validity=False)
tx_in.prev_out = OutPoint.deserialize(s)
tx_in.script_sig = var_bytes.deserialize(s)
tx_in.sequence = int.from_bytes(s.read(4), byteorder="little", signed=False)

if assert_valid:
tx_in.assert_valid()
return tx_in
return cls(prev_out, script_sig, sequence, check_validity=assert_valid)
10 changes: 3 additions & 7 deletions btclib/tx_out.py
Expand Up @@ -130,13 +130,9 @@ def deserialize(
cls: Type[_TxOut], data: BinaryData, assert_valid: bool = True
) -> _TxOut:
stream = bytesio_from_binarydata(data)
tx_out = cls(check_validity=False)
tx_out.value = int.from_bytes(stream.read(8), byteorder="little", signed=False)
tx_out.script_pub_key = var_bytes.deserialize(stream)

if assert_valid:
tx_out.assert_valid()
return tx_out
value = int.from_bytes(stream.read(8), byteorder="little", signed=False)
script_pub_key = var_bytes.deserialize(stream)
return cls(value, script_pub_key, check_validity=assert_valid)

@classmethod
def from_address(cls: Type[_TxOut], value: int, address: String) -> _TxOut:
Expand Down
9 changes: 2 additions & 7 deletions btclib/witness.py
Expand Up @@ -61,11 +61,6 @@ def deserialize(
"Return a Witness by parsing binary data."

data = bytesio_from_binarydata(data)
witness = cls(check_validity=False)

n = var_int.deserialize(data)
witness.stack = [var_bytes.deserialize(data) for _ in range(n)]

if assert_valid:
witness.assert_valid()
return witness
stack = [var_bytes.deserialize(data) for _ in range(n)]
return cls(stack, check_validity=assert_valid)

0 comments on commit 9e123a7

Please sign in to comment.