Skip to content

Commit

Permalink
Merge #545: Add tr() descriptors for Taproot output scripts
Browse files Browse the repository at this point in the history
d2d360c descriptors: Add PKDescriptor and parsing (Andrew Chow)
c464931 descriptors: Add TRDescriptor, parsing, and tests (Andrew Chow)
d131006 descriptor: Allow multiple subdescriptors (Andrew Chow)
48618bd descriptors: Explicitly list allowed contexts for functions (Andrew Chow)

Pull request description:

  Implements `tr()` descriptors

ACKs for top commit:
  Sjors:
    utACK d2d360c

Tree-SHA512: 68ce1b9507bf6a17a1a203d97839ca6308e039349a7191a04583b1ff3959507b6d1e02888a82a981aff25448fc1f0f4d8acba421cc5dfb541b9836f96481afe5
  • Loading branch information
achow101 committed Dec 15, 2021
2 parents f1949af + d2d360c commit e0a6c5c
Show file tree
Hide file tree
Showing 3 changed files with 223 additions and 61 deletions.
8 changes: 4 additions & 4 deletions hwilib/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,12 +448,12 @@ def displayaddress(
is_sh = isinstance(descriptor, SHDescriptor)
is_wsh = isinstance(descriptor, WSHDescriptor)
if is_sh or is_wsh:
assert descriptor.subdescriptor
descriptor = descriptor.subdescriptor
assert len(descriptor.subdescriptors) == 1
descriptor = descriptor.subdescriptors[0]
if isinstance(descriptor, WSHDescriptor):
is_wsh = True
assert descriptor.subdescriptor
descriptor = descriptor.subdescriptor
assert len(descriptor.subdescriptors) == 1
descriptor = descriptor.subdescriptors[0]
if isinstance(descriptor, MultisigDescriptor):
if is_sh and is_wsh:
addr_type = AddressType.SH_WIT
Expand Down
190 changes: 164 additions & 26 deletions hwilib/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
)


MAX_TAPROOT_NODES = 128


ExpandedScripts = namedtuple("ExpandedScripts", ["output_script", "redeem_script", "witness_script"])

def PolyMod(c: int, val: int) -> int:
Expand Down Expand Up @@ -209,21 +212,21 @@ def __lt__(self, other: 'PubkeyProvider') -> bool:
class Descriptor(object):
r"""
An abstract class for Descriptors themselves.
Descriptors can contain multiple :class:`PubkeyProvider`\ s and no more than one ``Descriptor`` as a subdescriptor.
Descriptors can contain multiple :class:`PubkeyProvider`\ s and multiple ``Descriptor`` as subdescriptors.
"""
def __init__(
self,
pubkeys: List['PubkeyProvider'],
subdescriptor: Optional['Descriptor'],
subdescriptors: List['Descriptor'],
name: str
) -> None:
r"""
:param pubkeys: The :class:`PubkeyProvider`\ s that are part of this descriptor
:param subdescriptor: The ``Descriptor`` that is part of this descriptor
:param subdescriptor: The ``Descriptor``s that are part of this descriptor
:param name: The name of the function for this descriptor
"""
self.pubkeys = pubkeys
self.subdescriptor = subdescriptor
self.subdescriptors = subdescriptors
self.name = name

def to_string_no_checksum(self) -> str:
Expand All @@ -235,7 +238,7 @@ def to_string_no_checksum(self) -> str:
return "{}({}{})".format(
self.name,
",".join([p.to_string() for p in self.pubkeys]),
self.subdescriptor.to_string_no_checksum() if self.subdescriptor else ""
self.subdescriptors[0].to_string_no_checksum() if len(self.subdescriptors) > 0 else ""
)

def to_string(self) -> str:
Expand All @@ -253,6 +256,20 @@ def expand(self, pos: int) -> "ExpandedScripts":
raise NotImplementedError("The Descriptor base class does not implement this method")


class PKDescriptor(Descriptor):
"""
A descriptor for ``pk()`` descriptors
"""
def __init__(
self,
pubkey: 'PubkeyProvider'
) -> None:
"""
:param pubkey: The :class:`PubkeyProvider` for this descriptor
"""
super().__init__([pubkey], [], "pk")


class PKHDescriptor(Descriptor):
"""
A descriptor for ``pkh()`` descriptors
Expand All @@ -264,7 +281,7 @@ def __init__(
"""
:param pubkey: The :class:`PubkeyProvider` for this descriptor
"""
super().__init__([pubkey], None, "pkh")
super().__init__([pubkey], [], "pkh")

def expand(self, pos: int) -> "ExpandedScripts":
script = b"\x76\xa9\x14" + hash160(self.pubkeys[0].get_pubkey_bytes(pos)) + b"\x88\xac"
Expand All @@ -282,7 +299,7 @@ def __init__(
"""
:param pubkey: The :class:`PubkeyProvider` for this descriptor
"""
super().__init__([pubkey], None, "wpkh")
super().__init__([pubkey], [], "wpkh")

def expand(self, pos: int) -> "ExpandedScripts":
script = b"\x00\x14" + hash160(self.pubkeys[0].get_pubkey_bytes(pos))
Expand All @@ -304,7 +321,7 @@ def __init__(
:param thresh: The number of keys required to sign this multisig
:param is_sorted: Whether this is a ``sortedmulti()`` descriptor
"""
super().__init__(pubkeys, None, "sortedmulti" if is_sorted else "multi")
super().__init__(pubkeys, [], "sortedmulti" if is_sorted else "multi")
self.thresh = thresh
self.is_sorted = is_sorted
if self.is_sorted:
Expand Down Expand Up @@ -336,16 +353,16 @@ class SHDescriptor(Descriptor):
"""
def __init__(
self,
subdescriptor: Optional['Descriptor']
subdescriptor: 'Descriptor'
) -> None:
"""
:param subdescriptor: The :class:`Descriptor` that is a sub-descriptor for this descriptor
"""
super().__init__([], subdescriptor, "sh")
super().__init__([], [subdescriptor], "sh")

def expand(self, pos: int) -> "ExpandedScripts":
assert self.subdescriptor
redeem_script, _, witness_script = self.subdescriptor.expand(pos)
assert len(self.subdescriptors) == 1
redeem_script, _, witness_script = self.subdescriptors[0].expand(pos)
script = b"\xa9\x14" + hash160(redeem_script) + b"\x87"
return ExpandedScripts(script, redeem_script, witness_script)

Expand All @@ -356,20 +373,57 @@ class WSHDescriptor(Descriptor):
"""
def __init__(
self,
subdescriptor: Optional['Descriptor']
subdescriptor: 'Descriptor'
) -> None:
"""
:param pubkey: The :class:`Descriptor` that is a sub-descriptor for this descriptor
:param subdescriptor: The :class:`Descriptor` that is a sub-descriptor for this descriptor
"""
super().__init__([], subdescriptor, "wsh")
super().__init__([], [subdescriptor], "wsh")

def expand(self, pos: int) -> "ExpandedScripts":
assert self.subdescriptor
witness_script, _, _ = self.subdescriptor.expand(pos)
assert len(self.subdescriptors) == 1
witness_script, _, _ = self.subdescriptors[0].expand(pos)
script = b"\x00\x20" + sha256(witness_script)
return ExpandedScripts(script, None, witness_script)


class TRDescriptor(Descriptor):
"""
A descriptor for ``tr()`` descriptors
"""
def __init__(
self,
internal_key: 'PubkeyProvider',
subdescriptors: List['Descriptor'] = [],
depths: List[int] = []
) -> None:
"""
:param internal_key: The :class:`PubkeyProvider` that is the internal key for this descriptor
:param subdescriptors: The :class:`Descriptor`s that are the leaf scripts for this descriptor
:param depths: The depths of the leaf scripts in the same order as `subdescriptors`
"""
super().__init__([internal_key], subdescriptors, "tr")
self.depths = depths

def to_string_no_checksum(self) -> str:
r = f"{self.name}({self.pubkeys[0].to_string()}"
path: List[bool] = [] # Track left or right for each depth
for p, depth in enumerate(self.depths):
r += ","
while len(path) <= depth:
if len(path) > 0:
r += "{"
path.append(False)
r += self.subdescriptors[p].to_string_no_checksum()
while len(path) > 0 and path[-1]:
if len(path) > 0:
r += "}"
path.pop()
if len(path) > 0:
path[-1] = True
r += ")"
return r

def _get_func_expr(s: str) -> Tuple[str, str]:
"""
Get the function name and then the expression inside
Expand All @@ -383,6 +437,41 @@ def _get_func_expr(s: str) -> Tuple[str, str]:
return s[0:start], s[start + 1:end]


def _get_const(s: str, const: str) -> str:
"""
Get the first character of the string, make sure it is the expected character,
and return the rest of the string
:param s: The string that begins with a constant character
:param const: The constant character
:return: The remainder of the string without the constant character
:raises: ValueError: if the first character is not the constant character
"""
if s[0] != const:
raise ValueError(f"Expected '{const}' but got '{s[0]}'")
return s[1:]


def _get_expr(s: str) -> Tuple[str, str]:
"""
Extract the expression that ``s`` begins with.
This will return the initial part of ``s``, up to the first comma or closing brace,
skipping ones that are surrounded by braces.
:param s: The string to extract the expression from
:return: A pair with the first item being the extracted expression and the second the rest of the string
"""
level: int = 0
for i, c in enumerate(s):
if c in ["(", "{"]:
level += 1
elif level > 0 and c in [")", "}"]:
level -= 1
elif level == 0 and c in [")", "}", ","]:
break
return s[0:i], s[i:]

def parse_pubkey(expr: str) -> Tuple['PubkeyProvider', str]:
"""
Parses an individual pubkey expression from a string that may contain more than one pubkey expression.
Expand Down Expand Up @@ -416,6 +505,9 @@ class _ParseDescriptorContext(Enum):
P2WSH = 3
"""Within a ``wsh()`` descriptor"""

P2TR = 4
"""Within a ``tr()`` descriptor"""


def _parse_descriptor(desc: str, ctx: '_ParseDescriptorContext') -> 'Descriptor':
"""
Expand All @@ -430,12 +522,21 @@ def _parse_descriptor(desc: str, ctx: '_ParseDescriptorContext') -> 'Descriptor'
:raises: ValueError: if the descriptor is malformed
"""
func, expr = _get_func_expr(desc)
if func == "pk":
pubkey, expr = parse_pubkey(expr)
if expr:
raise ValueError("more than one pubkey in pk descriptor")
return PKDescriptor(pubkey)
if func == "pkh":
if not (ctx == _ParseDescriptorContext.TOP or ctx == _ParseDescriptorContext.P2SH or ctx == _ParseDescriptorContext.P2WSH):
raise ValueError("Can only have pkh at top level, in sh(), or in wsh()")
pubkey, expr = parse_pubkey(expr)
if expr:
raise ValueError("More than one pubkey in pkh descriptor")
return PKHDescriptor(pubkey)
if func == "sortedmulti" or func == "multi":
if not (ctx == _ParseDescriptorContext.TOP or ctx == _ParseDescriptorContext.P2SH or ctx == _ParseDescriptorContext.P2WSH):
raise ValueError("Can only have multi/sortedmulti at top level, in sh(), or in wsh()")
is_sorted = func == "sortedmulti"
comma_idx = expr.index(",")
thresh = int(expr[:comma_idx])
Expand All @@ -453,23 +554,60 @@ def _parse_descriptor(desc: str, ctx: '_ParseDescriptorContext') -> 'Descriptor'
if ctx == _ParseDescriptorContext.TOP and len(pubkeys) > 3:
raise ValueError("Cannot have {} pubkeys in bare multisig: only at most 3 pubkeys")
return MultisigDescriptor(pubkeys, thresh, is_sorted)
if ctx != _ParseDescriptorContext.P2WSH and func == "wpkh":
if func == "wpkh":
if not (ctx == _ParseDescriptorContext.TOP or ctx == _ParseDescriptorContext.P2SH):
raise ValueError("Can only have wpkh() at top level or inside sh()")
pubkey, expr = parse_pubkey(expr)
if expr:
raise ValueError("More than one pubkey in pkh descriptor")
return WPKHDescriptor(pubkey)
elif ctx == _ParseDescriptorContext.P2WSH and func == "wpkh":
raise ValueError("Cannot have wpkh within wsh")
if ctx == _ParseDescriptorContext.TOP and func == "sh":
if func == "sh":
if ctx != _ParseDescriptorContext.TOP:
raise ValueError("Can only have sh() at top level")
subdesc = _parse_descriptor(expr, _ParseDescriptorContext.P2SH)
return SHDescriptor(subdesc)
elif ctx != _ParseDescriptorContext.TOP and func == "sh":
raise ValueError("Cannot have sh in non-top level")
if ctx != _ParseDescriptorContext.P2WSH and func == "wsh":
if func == "wsh":
if not (ctx == _ParseDescriptorContext.TOP or ctx == _ParseDescriptorContext.P2SH):
raise ValueError("Can only have wsh() at top level or inside sh()")
subdesc = _parse_descriptor(expr, _ParseDescriptorContext.P2WSH)
return WSHDescriptor(subdesc)
elif ctx == _ParseDescriptorContext.P2WSH and func == "wsh":
raise ValueError("Cannot have wsh within wsh")
if func == "tr":
if ctx != _ParseDescriptorContext.TOP:
raise ValueError("Can only have tr at top level")
internal_key, expr = parse_pubkey(expr)
subscripts = []
depths = []
if expr:
# Path from top of the tree to what we're currently processing.
# branches[i] == False: left branch in the i'th step from the top
# branches[i] == true: right branch
branches = []
while True:
# Process open braces
while True:
try:
expr = _get_const(expr, "{")
branches.append(False)
except ValueError:
break
if len(branches) > MAX_TAPROOT_NODES:
raise ValueError("tr() suports at most {MAX_TAPROOT_NODES} nesting levels")
# Process script expression
sarg, expr = _get_expr(expr)
subscripts.append(_parse_descriptor(sarg, _ParseDescriptorContext.P2TR))
depths.append(len(branches))
# Process closing braces
while len(branches) > 0 and branches[-1]:
expr = _get_const(expr, "}")
branches.pop()
# If we're at the end of a left branch, expect a comma
if len(branches) > 0 and not branches[-1]:
expr = _get_const(expr, ",")
branches[-1] = True

if len(branches) == 0:
break
return TRDescriptor(internal_key, subscripts, depths)
if ctx == _ParseDescriptorContext.P2SH:
raise ValueError("A function is needed within P2SH")
elif ctx == _ParseDescriptorContext.P2WSH:
Expand Down

0 comments on commit e0a6c5c

Please sign in to comment.