Skip to content

Commit

Permalink
Merge pull request #3450 from lbryio/deterministic_channel_keys
Browse files Browse the repository at this point in the history
deterministic channel keys (requires wallet server re-sync)
  • Loading branch information
eukreign committed Dec 23, 2021
2 parents 5eb95d7 + c80b30f commit 8076000
Show file tree
Hide file tree
Showing 21 changed files with 505 additions and 148 deletions.
22 changes: 11 additions & 11 deletions lbry/extras/daemon/daemon.py
Expand Up @@ -8,7 +8,6 @@
import inspect
import typing
import random
import hashlib
import tracemalloc
from decimal import Decimal
from urllib.parse import urlencode, quote
Expand All @@ -17,7 +16,6 @@
from traceback import format_exc
from functools import wraps, partial

import ecdsa
import base58
from aiohttp import web
from prometheus_client import generate_latest as prom_generate_latest, Gauge, Histogram, Counter
Expand All @@ -29,6 +27,7 @@
)
from lbry.wallet.dewies import dewies_to_lbc, lbc_to_dewies, dict_values_to_lbc
from lbry.wallet.constants import TXO_TYPES, CLAIM_TYPE_NAMES
from lbry.wallet.bip32 import PrivateKey

from lbry import utils
from lbry.conf import Config, Setting, NOT_SET
Expand Down Expand Up @@ -2704,12 +2703,13 @@ async def jsonrpc_channel_create(
name, claim, amount, claim_address, funding_accounts, funding_accounts[0]
)
txo = tx.outputs[0]
await txo.generate_channel_private_key()
txo.set_channel_private_key(
await funding_accounts[0].generate_channel_private_key()
)

await tx.sign(funding_accounts)

if not preview:
account.add_channel_private_key(txo.private_key)
wallet.save()
await self.broadcast_or_release(tx, blocking)
self.component_manager.loop.create_task(self.storage.save_claims([self._old_get_temp_claim_info(
Expand Down Expand Up @@ -2858,7 +2858,9 @@ async def jsonrpc_channel_update(
new_txo = tx.outputs[0]

if new_signing_key:
await new_txo.generate_channel_private_key()
new_txo.set_channel_private_key(
await funding_accounts[0].generate_channel_private_key()
)
else:
new_txo.private_key = old_txo.private_key

Expand All @@ -2867,7 +2869,6 @@ async def jsonrpc_channel_update(
await tx.sign(funding_accounts)

if not preview:
account.add_channel_private_key(new_txo.private_key)
wallet.save()
await self.broadcast_or_release(tx, blocking)
self.component_manager.loop.create_task(self.storage.save_claims([self._old_get_temp_claim_info(
Expand Down Expand Up @@ -3039,7 +3040,7 @@ async def jsonrpc_channel_export(self, channel_id=None, channel_name=None, accou
'channel_id': channel.claim_id,
'holding_address': address,
'holding_public_key': public_key.extended_key_string(),
'signing_private_key': channel.private_key.to_pem().decode()
'signing_private_key': channel.private_key.signing_key.to_pem().decode()
}
return base58.b58encode(json.dumps(export, separators=(',', ':')))

Expand All @@ -3062,15 +3063,14 @@ async def jsonrpc_channel_import(self, channel_data, wallet_id=None):

decoded = base58.b58decode(channel_data)
data = json.loads(decoded)
channel_private_key = ecdsa.SigningKey.from_pem(
data['signing_private_key'], hashfunc=hashlib.sha256
channel_private_key = PrivateKey.from_pem(
self.ledger, data['signing_private_key']
)
public_key_der = channel_private_key.get_verifying_key().to_der()

# check that the holding_address hasn't changed since the export was made
holding_address = data['holding_address']
channels, _, _, _ = await self.ledger.claim_search(
wallet.accounts, public_key_id=self.ledger.public_key_to_address(public_key_der)
wallet.accounts, public_key_id=channel_private_key.address
)
if channels and channels[0].get_address(self.ledger) != holding_address:
holding_address = channels[0].get_address(self.ledger)
Expand Down
4 changes: 2 additions & 2 deletions lbry/extras/daemon/json_response_encoder.py
Expand Up @@ -10,7 +10,7 @@
from lbry.schema.support import Support
from lbry.torrent.torrent_manager import TorrentSource
from lbry.wallet import Wallet, Ledger, Account, Transaction, Output
from lbry.wallet.bip32 import PubKey
from lbry.wallet.bip32 import PublicKey
from lbry.wallet.dewies import dewies_to_lbc
from lbry.stream.managed_stream import ManagedStream

Expand Down Expand Up @@ -138,7 +138,7 @@ def default(self, obj): # pylint: disable=method-hidden,arguments-renamed,too-m
return self.encode_claim(obj)
if isinstance(obj, Support):
return obj.to_dict()
if isinstance(obj, PubKey):
if isinstance(obj, PublicKey):
return obj.extended_key_string()
if isinstance(obj, datetime):
return obj.strftime("%Y%m%dT%H:%M:%S")
Expand Down
11 changes: 9 additions & 2 deletions lbry/schema/claim.py
Expand Up @@ -2,6 +2,9 @@
from typing import List
from binascii import hexlify, unhexlify

from asn1crypto.keys import PublicKeyInfo
from coincurve import PublicKey as cPublicKey

from google.protobuf.json_format import MessageToDict
from google.protobuf.message import DecodeError
from hachoir.core.log import log as hachoir_log
Expand Down Expand Up @@ -346,15 +349,19 @@ def to_dict(self):

@property
def public_key(self) -> str:
return hexlify(self.message.public_key).decode()
return hexlify(self.public_key_bytes).decode()

@public_key.setter
def public_key(self, sd_public_key: str):
self.message.public_key = unhexlify(sd_public_key.encode())

@property
def public_key_bytes(self) -> bytes:
return self.message.public_key
if len(self.message.public_key) == 33:
return self.message.public_key
public_key_info = PublicKeyInfo.load(self.message.public_key)
public_key = cPublicKey(public_key_info.native['public_key'])
return public_key.format(compressed=True)

@public_key_bytes.setter
def public_key_bytes(self, public_key: bytes):
Expand Down
15 changes: 15 additions & 0 deletions lbry/testcase.py
Expand Up @@ -17,8 +17,10 @@
from lbry.wallet import WalletManager, Wallet, Ledger, Account, Transaction
from lbry.conf import Config
from lbry.wallet.util import satoshis_to_coins
from lbry.wallet.dewies import lbc_to_dewies
from lbry.wallet.orchstr8 import Conductor
from lbry.wallet.orchstr8.node import BlockchainNode, WalletNode, HubNode
from lbry.schema.claim import Claim

from lbry.extras.daemon.daemon import Daemon, jsonrpc_dumps_pretty
from lbry.extras.daemon.components import Component, WalletComponent
Expand Down Expand Up @@ -506,6 +508,19 @@ async def confirm_and_render(self, awaitable, confirm, return_tx=False) -> Trans
return self.sout(tx)
return tx

async def create_nondeterministic_channel(self, name, price, pubkey_bytes, daemon=None):
account = (daemon or self.daemon).wallet_manager.default_account
claim_address = await account.receiving.get_or_create_usable_address()
claim = Claim()
claim.channel.public_key_bytes = pubkey_bytes
tx = await Transaction.claim_create(
name, claim, lbc_to_dewies(price),
claim_address, [self.account], self.account
)
await tx.sign([self.account])
await (daemon or self.daemon).broadcast_or_release(tx, False)
return self.sout(tx)

def create_upload_file(self, data, prefix=None, suffix=None):
file_path = tempfile.mktemp(prefix=prefix or "tmp", suffix=suffix or "", dir=self.daemon.conf.upload_dir)
with open(file_path, 'w+b') as file:
Expand Down
2 changes: 1 addition & 1 deletion lbry/wallet/__init__.py
Expand Up @@ -10,7 +10,7 @@
from .manager import WalletManager
from .network import Network
from .ledger import Ledger, RegTestLedger, TestNetLedger, BlockHeightEvent
from .account import Account, AddressManager, SingleKey, HierarchicalDeterministic
from .account import Account, AddressManager, SingleKey, HierarchicalDeterministic, DeterministicChannelKeyManager
from .transaction import Transaction, Output, Input
from .script import OutputScript, InputScript
from .database import SQLiteMixin, Database
Expand Down
98 changes: 65 additions & 33 deletions lbry/wallet/account.py
Expand Up @@ -9,11 +9,10 @@
from string import hexdigits
from typing import Type, Dict, Tuple, Optional, Any, List

import ecdsa
from lbry.error import InvalidPasswordError
from lbry.crypto.crypt import aes_encrypt, aes_decrypt

from .bip32 import PrivateKey, PubKey, from_extended_key_string
from .bip32 import PrivateKey, PublicKey, KeyPath, from_extended_key_string
from .mnemonic import Mnemonic
from .constants import COIN, TXO_TYPES
from .transaction import Transaction, Input, Output
Expand All @@ -34,6 +33,44 @@ def validate_claim_id(claim_id):
raise Exception("Claim id is not hex encoded")


class DeterministicChannelKeyManager:

def __init__(self, account: 'Account'):
self.account = account
self.last_known = 0
self.cache = {}
self.private_key: Optional[PrivateKey] = None
if account.private_key is not None:
self.private_key = account.private_key.child(KeyPath.CHANNEL)

def maybe_generate_deterministic_key_for_channel(self, txo):
if self.private_key is None:
return
next_private_key = self.private_key.child(self.last_known)
public_key = next_private_key.public_key
public_key_bytes = public_key.pubkey_bytes
if txo.claim.channel.public_key_bytes == public_key_bytes:
self.cache[public_key.address] = next_private_key
self.last_known += 1

async def ensure_cache_primed(self):
if self.private_key is not None:
await self.generate_next_key()

async def generate_next_key(self) -> PrivateKey:
db = self.account.ledger.db
while True:
next_private_key = self.private_key.child(self.last_known)
public_key = next_private_key.public_key
self.cache[public_key.address] = next_private_key
if not await db.is_channel_key_used(self.account, public_key):
return next_private_key
self.last_known += 1

def get_private_key_from_pubkey_hash(self, pubkey_hash) -> PrivateKey:
return self.cache.get(pubkey_hash)


class AddressManager:

name: str
Expand Down Expand Up @@ -79,7 +116,7 @@ def _query_addresses(self, **constraints):
def get_private_key(self, index: int) -> PrivateKey:
raise NotImplementedError

def get_public_key(self, index: int) -> PubKey:
def get_public_key(self, index: int) -> PublicKey:
raise NotImplementedError

async def get_max_gap(self):
Expand Down Expand Up @@ -119,8 +156,8 @@ def __init__(self, account: 'Account', chain: int, gap: int, maximum_uses_per_ad
@classmethod
def from_dict(cls, account: 'Account', d: dict) -> Tuple[AddressManager, AddressManager]:
return (
cls(account, 0, **d.get('receiving', {'gap': 20, 'maximum_uses_per_address': 1})),
cls(account, 1, **d.get('change', {'gap': 6, 'maximum_uses_per_address': 1}))
cls(account, KeyPath.RECEIVE, **d.get('receiving', {'gap': 20, 'maximum_uses_per_address': 1})),
cls(account, KeyPath.CHANGE, **d.get('change', {'gap': 6, 'maximum_uses_per_address': 1}))
)

def merge(self, d: dict):
Expand All @@ -133,7 +170,7 @@ def to_dict_instance(self):
def get_private_key(self, index: int) -> PrivateKey:
return self.account.private_key.child(self.chain_number).child(index)

def get_public_key(self, index: int) -> PubKey:
def get_public_key(self, index: int) -> PublicKey:
return self.account.public_key.child(self.chain_number).child(index)

async def get_max_gap(self) -> int:
Expand Down Expand Up @@ -193,7 +230,7 @@ class SingleKey(AddressManager):
@classmethod
def from_dict(cls, account: 'Account', d: dict) \
-> Tuple[AddressManager, AddressManager]:
same_address_manager = cls(account, account.public_key, 0)
same_address_manager = cls(account, account.public_key, KeyPath.RECEIVE)
return same_address_manager, same_address_manager

def to_dict_instance(self):
Expand All @@ -202,7 +239,7 @@ def to_dict_instance(self):
def get_private_key(self, index: int) -> PrivateKey:
return self.account.private_key

def get_public_key(self, index: int) -> PubKey:
def get_public_key(self, index: int) -> PublicKey:
return self.account.public_key

async def get_max_gap(self) -> int:
Expand All @@ -224,17 +261,14 @@ def get_address_records(self, only_usable: bool = False, **constraints):

class Account:

mnemonic_class = Mnemonic
private_key_class = PrivateKey
public_key_class = PubKey
address_generators: Dict[str, Type[AddressManager]] = {
SingleKey.name: SingleKey,
HierarchicalDeterministic.name: HierarchicalDeterministic,
}

def __init__(self, ledger: 'Ledger', wallet: 'Wallet', name: str,
seed: str, private_key_string: str, encrypted: bool,
private_key: Optional[PrivateKey], public_key: PubKey,
private_key: Optional[PrivateKey], public_key: PublicKey,
address_generator: dict, modified_on: float, channel_keys: dict) -> None:
self.ledger = ledger
self.wallet = wallet
Expand All @@ -245,13 +279,14 @@ def __init__(self, ledger: 'Ledger', wallet: 'Wallet', name: str,
self.private_key_string = private_key_string
self.init_vectors: Dict[str, bytes] = {}
self.encrypted = encrypted
self.private_key = private_key
self.public_key = public_key
self.private_key: Optional[PrivateKey] = private_key
self.public_key: PublicKey = public_key
generator_name = address_generator.get('name', HierarchicalDeterministic.name)
self.address_generator = self.address_generators[generator_name]
self.receiving, self.change = self.address_generator.from_dict(self, address_generator)
self.address_managers = {am.chain_number: am for am in (self.receiving, self.change)}
self.channel_keys = channel_keys
self.deterministic_channel_keys = DeterministicChannelKeyManager(self)
ledger.add_account(self)
wallet.add_account(self)

Expand All @@ -266,19 +301,19 @@ def generate(cls, ledger: 'Ledger', wallet: 'Wallet',
name: str = None, address_generator: dict = None):
return cls.from_dict(ledger, wallet, {
'name': name,
'seed': cls.mnemonic_class().make_seed(),
'seed': Mnemonic().make_seed(),
'address_generator': address_generator or {}
})

@classmethod
def get_private_key_from_seed(cls, ledger: 'Ledger', seed: str, password: str):
return cls.private_key_class.from_seed(
ledger, cls.mnemonic_class.mnemonic_to_seed(seed, password or 'lbryum')
return PrivateKey.from_seed(
ledger, Mnemonic.mnemonic_to_seed(seed, password or 'lbryum')
)

@classmethod
def keys_from_dict(cls, ledger: 'Ledger', d: dict) \
-> Tuple[str, Optional[PrivateKey], PubKey]:
-> Tuple[str, Optional[PrivateKey], PublicKey]:
seed = d.get('seed', '')
private_key_string = d.get('private_key', '')
private_key = None
Expand Down Expand Up @@ -449,7 +484,7 @@ def get_private_key(self, chain: int, index: int) -> PrivateKey:
assert not self.encrypted, "Cannot get private key on encrypted wallet account."
return self.address_managers[chain].get_private_key(index)

def get_public_key(self, chain: int, index: int) -> PubKey:
def get_public_key(self, chain: int, index: int) -> PublicKey:
return self.address_managers[chain].get_public_key(index)

def get_balance(self, confirmations=0, include_claims=False, read_only=False, **constraints):
Expand Down Expand Up @@ -520,33 +555,30 @@ async def fund(self, to_account, amount=None, everything=False,

return tx

def add_channel_private_key(self, private_key):
public_key_bytes = private_key.get_verifying_key().to_der()
channel_pubkey_hash = self.ledger.public_key_to_address(public_key_bytes)
self.channel_keys[channel_pubkey_hash] = private_key.to_pem().decode()
async def generate_channel_private_key(self):
return await self.deterministic_channel_keys.generate_next_key()

def add_channel_private_key(self, private_key: PrivateKey):
self.channel_keys[private_key.address] = private_key.to_pem().decode()

async def get_channel_private_key(self, public_key_bytes):
async def get_channel_private_key(self, public_key_bytes) -> PrivateKey:
channel_pubkey_hash = self.ledger.public_key_to_address(public_key_bytes)
private_key_pem = self.channel_keys.get(channel_pubkey_hash)
if private_key_pem:
return await asyncio.get_event_loop().run_in_executor(
None, ecdsa.SigningKey.from_pem, private_key_pem, sha256
)
return PrivateKey.from_pem(self.ledger, private_key_pem)
return self.deterministic_channel_keys.get_private_key_from_pubkey_hash(channel_pubkey_hash)

async def maybe_migrate_certificates(self):
def to_der(private_key_pem):
return ecdsa.SigningKey.from_pem(private_key_pem, hashfunc=sha256).get_verifying_key().to_der()

if not self.channel_keys:
return
channel_keys = {}
for private_key_pem in self.channel_keys.values():
if not isinstance(private_key_pem, str):
continue
if "-----BEGIN EC PRIVATE KEY-----" not in private_key_pem:
if not private_key_pem.startswith("-----BEGIN"):
continue
public_key_der = await asyncio.get_event_loop().run_in_executor(None, to_der, private_key_pem)
channel_keys[self.ledger.public_key_to_address(public_key_der)] = private_key_pem
private_key = PrivateKey.from_pem(self.ledger, private_key_pem)
channel_keys[private_key.address] = private_key_pem
if self.channel_keys != channel_keys:
self.channel_keys = channel_keys
self.wallet.save()
Expand Down

0 comments on commit 8076000

Please sign in to comment.