Skip to content

Commit

Permalink
Partial refactoring in preparation for code sharing between Keepkey a…
Browse files Browse the repository at this point in the history
…nd Trezor, but it turns out that Keepkey observe BIP143 which signs the spent output value, so it doesn't need the previous transactions as Trezor does.
  • Loading branch information
rt121212121 committed Nov 5, 2020
1 parent 2c51e63 commit b7f2d6d
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 61 deletions.
1 change: 0 additions & 1 deletion contrib/requirements/requirements-hw.txt
@@ -1,4 +1,3 @@
Cython>=0.27
trezor==0.12.0
keepkey==6.1.0
btchip-python==0.1.31
8 changes: 3 additions & 5 deletions electrumsv/devices/keepkey/client.py
Expand Up @@ -29,7 +29,7 @@
pack_be_uint32,
)

from keepkeylib.client import proto, BaseClient, ProtocolMixin
from keepkeylib.client import proto, BaseClient, ProtocolMixin, types

from electrumsv.exceptions import UserCancelled
from electrumsv.i18n import _
Expand All @@ -50,7 +50,6 @@ def __init__(self, transport, handler, plugin):
self.device = plugin.device
self.handler = handler
self.tx_api = plugin
self.types = plugin.types
self.msg = None
self.creating_wallet = False
self.used()
Expand Down Expand Up @@ -207,9 +206,8 @@ def callback_Failure(self, msg):
# However, making the user acknowledge they cancelled
# gets old very quickly, so we suppress those. The NotInitialized
# one is misnamed and indicates a passphrase request was cancelled.
if msg.code in (self.types.Failure_PinCancelled,
self.types.Failure_ActionCancelled,
self.types.Failure_NotInitialized):
if msg.code in (types.Failure_PinCancelled, types.Failure_ActionCancelled,
types.Failure_NotInitialized):
raise UserCancelled()
raise RuntimeError(msg.message)

Expand Down
122 changes: 69 additions & 53 deletions electrumsv/devices/keepkey/keepkey.py
Expand Up @@ -23,7 +23,7 @@
# SOFTWARE.

import threading
from typing import cast, Dict, Optional
from typing import cast, Dict, List, Optional

from bitcoinx import (
BIP32PublicKey, BIP32Derivation, bip32_decompose_chain_string, Address,
Expand All @@ -39,6 +39,19 @@
from electrumsv.transaction import classify_tx_output, Transaction
from electrumsv.wallet import AbstractAccount

logger = logs.get_logger("plugin.keepkey")

try:
from .client import KeepKeyClient
import keepkeylib
import keepkeylib.ckd_public
from keepkeylib.client import types
from usb1 import USBContext
KEEPKEYLIB = True
except Exception:
logger.exception("Failed to import keepkeylib")
KEEPKEYLIB = False

from ..hw_wallet import HW_PluginBase

# TREZOR initialization methods
Expand All @@ -51,9 +64,14 @@ class KeepKey_KeyStore(Hardware_KeyStore):
hw_type = 'keepkey'
device = KEEPKEY_PRODUCT_KEY

def get_derivation(self):
def get_derivation(self) -> str:
return self.derivation

def requires_input_transactions(self) -> bool:
# Keepkey has a 'tx_api' which is called to retrieve previous transactions, but it is
# not called for BSV coins as they use BIP143, where the spent output's value is signed.
return False

def get_client(self, force_pair=True):
return self.plugin.get_client(self, force_pair)

Expand All @@ -72,8 +90,10 @@ def sign_transaction(self, tx: Transaction, password: str,
prev_txs: Optional[Dict[bytes, Transaction]]=None) -> None:
if tx.is_complete():
return

assert prev_txs is None, "This keystore does not require input transactions"
# path of the xpubs that are involved
xpub_path = {}
xpub_path: Dict[str, str] = {}
for txin in tx.inputs:
for x_pubkey in txin.x_pubkeys:
if not x_pubkey.is_bip32_key():
Expand All @@ -95,24 +115,18 @@ class KeepKeyPlugin(HW_PluginBase):
minimum_firmware = (4, 0, 0)
keystore_class = KeepKey_KeyStore

DEVICE_IDS = (KEEPKEY_PRODUCT_KEY,)

def __init__(self, name):
super().__init__(name)
try:
from . import client
import keepkeylib
import keepkeylib.ckd_public
from usb1 import USBContext
self.client_class = client.KeepKeyClient
self.ckd_public = keepkeylib.ckd_public
self.types = keepkeylib.client.types
self.DEVICE_IDS = (KEEPKEY_PRODUCT_KEY,)
self.usb_context = USBContext()
self.usb_context.open()
self.libraries_available = True
except ImportError:
self.libraries_available = False

self.logger = logs.get_logger("plugin.keepkey")

self.libraries_available = KEEPKEYLIB
if KEEPKEYLIB:
try:
self.usb_context = USBContext()
self.usb_context.open()
except Exception:
self.libraries_available = False

self.main_thread = threading.current_thread()

Expand All @@ -130,19 +144,19 @@ def _libusb_enumerate(self):
yield dev

def _enumerate_hid(self):
if self.libraries_available:
if KEEPKEYLIB:
from keepkeylib.transport_hid import HidTransport
return HidTransport.enumerate()
return []

def _enumerate_web_usb(self):
if self.libraries_available:
if KEEPKEYLIB:
from keepkeylib.transport_webusb import WebUsbTransport
return self._libusb_enumerate()
return []

def _get_transport(self, device):
self.logger.debug("Trying to connect over USB...")
logger.debug("Trying to connect over USB...")

if device.path.startswith('web_usb'):
for d in self._enumerate_web_usb():
Expand Down Expand Up @@ -187,25 +201,25 @@ def create_client(self, device, handler):
try:
transport = self._get_transport(device)
except Exception as e:
self.logger.error("cannot connect to device")
logger.error("cannot connect to device")
raise

self.logger.debug("connected to device at %s", device.path)
logger.debug("connected to device at %s", device.path)

client = self.client_class(transport, handler, self)
client = KeepKeyClient(transport, handler, self)

# Try a ping for device sanity
try:
client.ping('t')
except Exception as e:
self.logger.error("ping failed %s", e)
logger.error("ping failed %s", e)
return None

if not client.atleast_version(*self.minimum_firmware):
msg = (_('Outdated {} firmware for device labelled {}. Please '
'download the updated firmware from {}')
.format(self.device, client.label(), self.firmware_URL))
self.logger.error(msg)
logger.error(msg)
handler.show_error(msg)
return None

Expand Down Expand Up @@ -298,10 +312,10 @@ def get_master_public_key(self, device_id, derivation, wizard):
client.handler = self.create_handler(wizard)
return client.get_master_public_key(derivation)

def sign_transaction(self, keystore, tx, xpub_path):
self.xpub_path = xpub_path
def sign_transaction(self, keystore: KeepKey_KeyStore, tx: Transaction,
xpub_path: Dict[str, str]) -> None:
client = self.get_client(keystore)
inputs = self.tx_inputs(tx)
inputs = self.tx_inputs(tx, xpub_path)
outputs = self.tx_outputs(keystore, keystore.get_derivation(), tx)
signatures = client.sign_tx(self.get_coin_name(client), inputs, outputs,
lock_time=tx.locktime)[0]
Expand All @@ -315,22 +329,23 @@ def show_key(self, account: AbstractAccount, keyinstance_id: int) -> None:
subpath = '/'.join(str(x) for x in derivation_path)
address_path = f"{keystore.derivation}/{subpath}"
address_n = bip32_decompose_chain_string(address_path)
script_type = self.types.SPENDADDRESS
script_type = types.SPENDADDRESS
client.get_address(Net.KEEPKEY_DISPLAY_COIN_NAME, address_n,
True, script_type=script_type)

def tx_inputs(self, tx):
def tx_inputs(self, tx: Transaction, xpub_path: Dict[str, str]) -> List[types.TxInputType]:
inputs = []
for txin in tx.inputs:
txinputtype = self.types.TxInputType()
txinputtype = types.TxInputType()

x_pubkeys = txin.x_pubkeys
if len(x_pubkeys) == 1:
x_pubkey = x_pubkeys[0]
xpub, path = x_pubkey.bip32_extended_key_and_path()
xpub_n = tuple(bip32_decompose_chain_string(self.xpub_path[xpub]))
txinputtype.address_n.extend(xpub_n + path)
txinputtype.script_type = self.types.SPENDADDRESS
xpub_n = bip32_decompose_chain_string(xpub_path[xpub])
txinputtype.address_n.extend(xpub_n)
txinputtype.address_n.extend(path)
txinputtype.script_type = types.SPENDADDRESS
else:
def f(x_pubkey):
if x_pubkey.is_bip32_key():
Expand All @@ -339,26 +354,27 @@ def f(x_pubkey):
xpub = BIP32PublicKey(x_pubkey.to_public_key(), NULL_DERIVATION, Net.COIN)
xpub = xpub.to_extended_key_string()
path = []
node = self.ckd_public.deserialize(xpub)
return self.types.HDNodePathType(node=node, address_n=path)
node = keepkeylib.ckd_public.deserialize(xpub)
return types.HDNodePathType(node=node, address_n=path)
pubkeys = [f(x) for x in x_pubkeys]
multisig = self.types.MultisigRedeemScriptType(
multisig = types.MultisigRedeemScriptType(
pubkeys=pubkeys,
signatures=txin.stripped_signatures_with_blanks(),
m=txin.threshold,
)
script_type = self.types.SPENDMULTISIG
txinputtype = self.types.TxInputType(
script_type = types.SPENDMULTISIG
txinputtype = types.TxInputType(
script_type=script_type,
multisig=multisig
)
# find which key is mine
for x_pubkey in x_pubkeys:
if x_pubkey.is_bip32_key():
xpub, path = x_pubkey.bip32_extended_key_and_path()
if xpub in self.xpub_path:
xpub_n = tuple(bip32_decompose_chain_string(self.xpub_path[xpub]))
txinputtype.address_n.extend(xpub_n + path)
if xpub in xpub_path:
xpub_n = tuple(bip32_decompose_chain_string(xpub_path[xpub]))
txinputtype.address_n.extend(xpub_n)
txinputtype.address_n.extend(path)
break

txinputtype.prev_hash = bytes(reversed(txin.prev_hash))
Expand All @@ -383,32 +399,32 @@ def tx_outputs(self, keystore: KeepKey_KeyStore, derivation: str, tx: Transactio
has_change = True # no more than one change address
key_derivation, xpubs, m = info
if len(xpubs) == 1:
script_type = self.types.PAYTOADDRESS
txoutputtype = self.types.TxOutputType(
script_type = types.PAYTOADDRESS
txoutputtype = types.TxOutputType(
amount = tx_output.value,
script_type = script_type,
address_n = account_derivation + key_derivation,
)
else:
script_type = self.types.PAYTOMULTISIG
nodes = [self.ckd_public.deserialize(xpub) for xpub in xpubs]
pubkeys = [self.types.HDNodePathType(node=node, address_n=key_derivation)
script_type = types.PAYTOMULTISIG
nodes = [keepkeylib.ckd_public.deserialize(xpub) for xpub in xpubs]
pubkeys = [types.HDNodePathType(node=node, address_n=key_derivation)
for node in nodes]
multisig = self.types.MultisigRedeemScriptType(
multisig = types.MultisigRedeemScriptType(
pubkeys = pubkeys,
signatures = [b''] * len(pubkeys),
m = m)
txoutputtype = self.types.TxOutputType(
txoutputtype = types.TxOutputType(
multisig = multisig,
amount = tx_output.value,
address_n = account_derivation + key_derivation,
script_type = script_type)
else:
txoutputtype = self.types.TxOutputType()
txoutputtype = types.TxOutputType()
txoutputtype.amount = tx_output.value
address = classify_tx_output(tx_output)
if isinstance(address, Address):
txoutputtype.script_type = self.types.PAYTOADDRESS
txoutputtype.script_type = types.PAYTOADDRESS
txoutputtype.address = address.to_string()

outputs.append(txoutputtype)
Expand Down
1 change: 1 addition & 0 deletions electrumsv/devices/trezor/trezor.py
Expand Up @@ -76,6 +76,7 @@ def sign_transaction(self, tx: Transaction, password: str,
prev_txs: Optional[Dict[bytes, Transaction]]=None) -> None:
if tx.is_complete():
return

assert prev_txs is not None, "This keystore requires all input transactions"
# path of the xpubs that are involved
xpub_path: Dict[str, str] = {}
Expand Down
2 changes: 0 additions & 2 deletions electrumsv/keystore.py
Expand Up @@ -663,7 +663,6 @@ class Hardware_KeyStore(Xpub, KeyStore):

# Derived classes must set:
# - device
# - DEVICE_IDS
# - wallet_type
hw_type: str
device: str
Expand All @@ -685,7 +684,6 @@ def __init__(self, data: Dict[str, Any], row: Optional[MasterKeyRow]=None) -> No
self.label = data.get('label')
self.handler = None
self.plugin = None
self.libraries_available = False

def clean_up(self) -> None:
app_state.device_manager.unpair_xpub(self.xpub)
Expand Down

0 comments on commit b7f2d6d

Please sign in to comment.