Skip to content

Commit

Permalink
Merge pull request #1468 from lnbits/pyright3
Browse files Browse the repository at this point in the history
introduce pyright + fix issues (supersedes #1444)
  • Loading branch information
arcbtc committed Apr 5, 2023
2 parents 3855cf4 + 255638a commit 47df941
Show file tree
Hide file tree
Showing 26 changed files with 384 additions and 335 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -1,6 +1,6 @@
.PHONY: test

all: format check requirements.txt
all: format check

format: prettier isort black

Expand Down
43 changes: 23 additions & 20 deletions lnbits/bolt11.py
Expand Up @@ -66,11 +66,12 @@ def decode(pr: str) -> Invoice:
invoice.amount_msat = _unshorten_amount(amountstr)

# pull out date
invoice.date = data.read(35).uint
date_bin = data.read(35)
invoice.date = date_bin.uint # type: ignore

while data.pos != data.len:
tag, tagdata, data = _pull_tagged(data)
data_length = len(tagdata) / 5
data_length = len(tagdata or []) / 5

if tag == "d":
invoice.description = _trim_to_bytes(tagdata).decode()
Expand All @@ -79,7 +80,7 @@ def decode(pr: str) -> Invoice:
elif tag == "p" and data_length == 52:
invoice.payment_hash = _trim_to_bytes(tagdata).hex()
elif tag == "x":
invoice.expiry = tagdata.uint
invoice.expiry = tagdata.uint # type: ignore
elif tag == "n":
invoice.payee = _trim_to_bytes(tagdata).hex()
# this won't work in most cases, we must extract the payee
Expand All @@ -90,11 +91,11 @@ def decode(pr: str) -> Invoice:
s = bitstring.ConstBitStream(tagdata)
while s.pos + 264 + 64 + 32 + 32 + 16 < s.len:
route = Route(
pubkey=s.read(264).tobytes().hex(),
short_channel_id=_readable_scid(s.read(64).intbe),
base_fee_msat=s.read(32).intbe,
ppm_fee=s.read(32).intbe,
cltv=s.read(16).intbe,
pubkey=s.read(264).tobytes().hex(), # type: ignore
short_channel_id=_readable_scid(s.read(64).intbe), # type: ignore
base_fee_msat=s.read(32).intbe, # type: ignore
ppm_fee=s.read(32).intbe, # type: ignore
cltv=s.read(16).intbe, # type: ignore
)
invoice.route_hints.append(route)

Expand Down Expand Up @@ -202,7 +203,8 @@ def lnencode(addr, privkey):
)
data += tagged("r", route)
elif k == "f":
data += encode_fallback(v, addr.currency)
# NOTE: there was an error fallback here that's now removed
continue
elif k == "d":
data += tagged_bytes("d", v.encode())
elif k == "x":
Expand Down Expand Up @@ -244,19 +246,27 @@ def lnencode(addr, privkey):

class LnAddr:
def __init__(
self, paymenthash=None, amount=None, currency="bc", tags=None, date=None
self,
paymenthash=None,
amount=None,
currency="bc",
tags=None,
date=None,
fallback=None,
):
self.date = int(time.time()) if not date else int(date)
self.tags = [] if not tags else tags
self.unknown_tags = []
self.paymenthash = paymenthash
self.signature = None
self.pubkey = None
self.fallback = fallback
self.currency = currency
self.amount = amount

def __str__(self):
pubkey = bytes.hex(self.pubkey.serialize()).decode()
assert self.pubkey, "LnAddr, pubkey must be set"
pubkey = bytes.hex(self.pubkey.serialize())
tags = ", ".join([f"{k}={v}" for k, v in self.tags])
return f"LnAddr[{pubkey}, amount={self.amount}{self.currency} tags=[{tags}]]"

Expand All @@ -266,6 +276,7 @@ def shorten_amount(amount):
# Convert to pico initially
amount = int(amount * 10**12)
units = ["p", "n", "u", "m", ""]
unit = ""
for unit in units:
if amount % 1000 == 0:
amount //= 1000
Expand Down Expand Up @@ -304,14 +315,6 @@ def _pull_tagged(stream):
return (CHARSET[tag], stream.read(length * 5), stream)


def is_p2pkh(currency, prefix):
return prefix == base58_prefix_map[currency][0]


def is_p2sh(currency, prefix):
return prefix == base58_prefix_map[currency][1]


# Tagged field containing BitArray
def tagged(char, l):
# Tagged fields need to be zero-padded to 5 bits.
Expand Down Expand Up @@ -359,5 +362,5 @@ def bitarray_to_u5(barr):
ret = []
s = bitstring.ConstBitStream(barr)
while s.pos != s.len:
ret.append(s.read(5).uint)
ret.append(s.read(5).uint) # type: ignore
return ret
1 change: 1 addition & 0 deletions lnbits/commands.py
Expand Up @@ -41,6 +41,7 @@ async def migrate_databases():
"""Creates the necessary databases if they don't exist already; or migrates them."""

async with core_db.connect() as conn:
exists = False
if conn.type == SQLITE:
exists = await conn.fetchone(
"SELECT * FROM sqlite_master WHERE type='table' AND name='dbversions'"
Expand Down
17 changes: 12 additions & 5 deletions lnbits/core/crud.py
Expand Up @@ -206,14 +206,17 @@ async def create_wallet(
async def update_wallet(
wallet_id: str, new_name: str, conn: Optional[Connection] = None
) -> Optional[Wallet]:
return await (conn or db).execute(
await (conn or db).execute(
"""
UPDATE wallets SET
name = ?
WHERE id = ?
""",
(new_name, wallet_id),
)
wallet = await get_wallet(wallet_id=wallet_id, conn=conn)
assert wallet, "updated created wallet couldn't be retrieved"
return wallet


async def delete_wallet(
Expand Down Expand Up @@ -393,7 +396,7 @@ async def get_payments(
clause.append("checking_id NOT LIKE 'internal_%'")

if not filters:
filters = Filters()
filters = Filters(limit=None, offset=None)

rows = await (conn or db).fetchall(
f"""
Expand Down Expand Up @@ -712,15 +715,19 @@ async def update_admin_settings(data: EditableSettings):
await db.execute("UPDATE settings SET editable_settings = ?", (json.dumps(data),))


async def update_super_user(super_user: str):
async def update_super_user(super_user: str) -> SuperSettings:
await db.execute("UPDATE settings SET super_user = ?", (super_user,))
return await get_super_settings()
settings = await get_super_settings()
assert settings, "updated super_user settings could not be retrieved"
return settings


async def create_admin_settings(super_user: str, new_settings: dict):
sql = "INSERT INTO settings (super_user, editable_settings) VALUES (?, ?)"
await db.execute(sql, (super_user, json.dumps(new_settings)))
return await get_super_settings()
settings = await get_super_settings()
assert settings, "created admin settings could not be retrieved"
return settings


# db versions
Expand Down
31 changes: 14 additions & 17 deletions lnbits/core/services.py
@@ -1,7 +1,7 @@
import asyncio
import json
from io import BytesIO
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, TypedDict
from urllib.parse import parse_qs, urlparse

import httpx
Expand All @@ -17,6 +17,7 @@
from lnbits.settings import (
FAKE_WALLET,
EditableSettings,
SuperSettings,
get_wallet_class,
readonly_variables,
send_admin_user_to_saas,
Expand All @@ -43,11 +44,6 @@
)
from .models import Payment

try:
from typing import TypedDict
except ImportError: # pragma: nocover
from typing_extensions import TypedDict


class PaymentFailure(Exception):
pass
Expand Down Expand Up @@ -188,7 +184,7 @@ class PaymentKwargs(TypedDict):

# do the balance check
wallet = await get_wallet(wallet_id, conn=conn)
assert wallet
assert wallet, "Wallet for balancecheck could not be fetched"
if wallet.balance_msat < 0:
logger.debug("balance is too low, deleting temporary payment")
if not internal_checking_id and wallet.balance_msat > -fee_reserve_msat:
Expand Down Expand Up @@ -336,36 +332,37 @@ def int_to_bytes_suitable_der(x: int) -> bytes:

return b

def encode_strict_der(r_int, s_int, order):
def encode_strict_der(r: int, s: int, order: int):
# if s > order/2 verification will fail sometimes
# so we must fix it here (see https://github.com/indutny/elliptic/blob/e71b2d9359c5fe9437fbf46f1f05096de447de57/lib/elliptic/ec/index.js#L146-L147)
if s_int > order // 2:
s_int = order - s_int
if s > order // 2:
s = order - s

# now we do the strict DER encoding copied from
# https://github.com/KiriKiri/bip66 (without any checks)
r = int_to_bytes_suitable_der(r_int)
s = int_to_bytes_suitable_der(s_int)
r_temp = int_to_bytes_suitable_der(r)
s_temp = int_to_bytes_suitable_der(s)

r_len = len(r)
s_len = len(s)
r_len = len(r_temp)
s_len = len(s_temp)
sign_len = 6 + r_len + s_len

signature = BytesIO()
signature.write(0x30.to_bytes(1, "big", signed=False))
signature.write((sign_len - 2).to_bytes(1, "big", signed=False))
signature.write(0x02.to_bytes(1, "big", signed=False))
signature.write(r_len.to_bytes(1, "big", signed=False))
signature.write(r)
signature.write(r_temp)
signature.write(0x02.to_bytes(1, "big", signed=False))
signature.write(s_len.to_bytes(1, "big", signed=False))
signature.write(s)
signature.write(s_temp)

return signature.getvalue()

sig = key.sign_digest_deterministic(k1, sigencode=encode_strict_der)

async with httpx.AsyncClient() as client:
assert key.verifying_key, "LNURLauth verifying_key does not exist"
r = await client.get(
callback,
params={
Expand Down Expand Up @@ -469,7 +466,7 @@ def update_cached_settings(sets_dict: dict):
setattr(settings, "super_user", sets_dict["super_user"])


async def init_admin_settings(super_user: str = None):
async def init_admin_settings(super_user: Optional[str] = None) -> SuperSettings:
account = None
if super_user:
account = await get_account(super_user)
Expand Down
16 changes: 10 additions & 6 deletions lnbits/core/views/api.py
Expand Up @@ -411,8 +411,7 @@ async def payment_received() -> None:
typ, data = await send_queue.get()
if data:
jdata = json.dumps(dict(data.dict(), pending=False))

yield dict(data=jdata, event=typ)
yield dict(data=jdata, event=typ)
except asyncio.CancelledError:
logger.debug(f"removing listener for wallet {uid}")
api_invoice_listeners.pop(uid)
Expand All @@ -431,11 +430,12 @@ async def api_payments_sse(
)


# TODO: refactor this route into a public and admin one
@core_app.get("/api/v1/payments/{payment_hash}")
async def api_payment(payment_hash, X_Api_Key: Optional[str] = Header(None)):
# We use X_Api_Key here because we want this call to work with and without keys
# If a valid key is given, we also return the field "details", otherwise not
wallet = await get_wallet_for_key(X_Api_Key) if type(X_Api_Key) == str else None
wallet = await get_wallet_for_key(X_Api_Key) if type(X_Api_Key) == str else None # type: ignore

# we have to specify the wallet id here, because postgres and sqlite return internal payments in different order
# and get_standalone_payment otherwise just fetches the first one, causing unpredictable results
Expand Down Expand Up @@ -505,6 +505,7 @@ async def api_lnurlscan(code: str, wallet: WalletTypeInfo = Depends(get_key_type
params.update(callback=url) # with k1 already in it

lnurlauth_key = wallet.wallet.lnurlauth_key(domain)
assert lnurlauth_key.verifying_key
params.update(pubkey=lnurlauth_key.verifying_key.to_string("compressed").hex())
else:
async with httpx.AsyncClient() as client:
Expand Down Expand Up @@ -693,7 +694,7 @@ async def api_auditor():
if not error_message:
delta = node_balance - total_balance
else:
node_balance, delta = None, None
node_balance, delta = 0, 0

return {
"node_balance_msats": int(node_balance),
Expand Down Expand Up @@ -745,6 +746,7 @@ async def api_install_extension(
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND, detail="Release not found"
)

ext_info = InstallableExtension(
id=data.ext_id, name=data.ext_id, installed_release=release, icon=release.icon
)
Expand Down Expand Up @@ -824,8 +826,10 @@ async def api_uninstall_extension(ext_id: str, user: User = Depends(check_admin)
)


@core_app.get("/api/v1/extension/{ext_id}/releases")
async def get_extension_releases(ext_id: str, user: User = Depends(check_admin)):
@core_app.get(
"/api/v1/extension/{ext_id}/releases", dependencies=[Depends(check_admin)]
)
async def get_extension_releases(ext_id: str):
try:
extension_releases: List[
ExtensionRelease
Expand Down
9 changes: 4 additions & 5 deletions lnbits/core/views/public_api.py
Expand Up @@ -40,19 +40,18 @@ async def api_public_payment_longpolling(payment_hash):

response = None

async def payment_info_receiver(cancel_scope):
async for payment in payment_queue.get():
async def payment_info_receiver():
for payment in await payment_queue.get():
if payment.payment_hash == payment_hash:
nonlocal response
response = {"status": "paid"}
cancel_scope.cancel()

async def timeouter(cancel_scope):
await asyncio.sleep(45)
cancel_scope.cancel()

asyncio.create_task(payment_info_receiver())
asyncio.create_task(timeouter())
cancel_scope = asyncio.create_task(payment_info_receiver())
asyncio.create_task(timeouter(cancel_scope))

if response:
return response
Expand Down

0 comments on commit 47df941

Please sign in to comment.