Skip to content

Commit

Permalink
Merge pull request #17 from pipermerriam/piper/adjust-schema-for-knn-…
Browse files Browse the repository at this point in the history
…lookups

Add support for proximate queries
  • Loading branch information
pipermerriam committed Nov 23, 2020
2 parents 0540fff + ec0bc89 commit d9b7506
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 12 deletions.
6 changes: 6 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,12 @@ Constraints
:show-inheritance:


.. autoclass:: eth_enr.constraints.ClosestTo
:members:
:undoc-members:
:show-inheritance:


Exceptions
----------

Expand Down
21 changes: 21 additions & 0 deletions eth_enr/constraints.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from eth_typing import NodeID

from eth_enr.abc import ConstraintAPI


Expand Down Expand Up @@ -83,6 +85,25 @@ class HasTCPIPv6Endpoint(ConstraintAPI):
pass


class ClosestTo(ConstraintAPI):
    """
    Constrains ENR database queries to return records proximate to a specific
    `node_id`.

    The constraint only carries the target `node_id`; how it is turned into an
    ordering is up to the database implementation that receives it.

    .. code-block:: python

        >>> enr_db = ...
        >>> node_id = ...
        >>> from eth_enr.constraints import ClosestTo
        >>> for enr in enr_db.query(ClosestTo(node_id)):
        ...     print("ENR: ", enr)
    """

    # The node id that query results should be proximate to.
    node_id: NodeID

    def __init__(self, node_id: NodeID) -> None:
        self.node_id = node_id

    def __repr__(self) -> str:
        # Constraints get interpolated into error messages (e.g. when several
        # `ClosestTo` constraints are passed to one query), so show the target
        # instead of the default `<... object at 0x...>` form.
        return f"{type(self).__name__}({self.node_id!r})"


has_tcp_ipv4_endpoint = HasTCPIPv4Endpoint()
has_tcp_ipv6_endpoint = HasTCPIPv6Endpoint()
has_udp_ipv4_endpoint = HasUDPIPv4Endpoint()
Expand Down
30 changes: 27 additions & 3 deletions eth_enr/query_db.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import sqlite3
from typing import Iterable
from typing import Iterable, Optional

from eth_typing import NodeID
from eth_utils import to_tuple
Expand All @@ -15,6 +15,7 @@
UDP_PORT_ENR_KEY,
)
from eth_enr.constraints import (
ClosestTo,
HasTCPIPv4Endpoint,
HasTCPIPv6Endpoint,
HasUDPIPv4Endpoint,
Expand Down Expand Up @@ -51,10 +52,26 @@ def _get_required_keys(*constraints: ConstraintAPI) -> Iterable[bytes]:
elif isinstance(constraint, HasUDPIPv6Endpoint):
yield IP_V6_ADDRESS_ENR_KEY
yield UDP6_PORT_ENR_KEY
elif isinstance(constraint, ClosestTo):
continue
else:
raise TypeError(f"Unsupported constraint type: {type(constraint)}")


def _get_order_closest_to(*constraints: ConstraintAPI) -> Optional[NodeID]:
    """
    Return the `node_id` of the single `ClosestTo` constraint, if any.

    Returns ``None`` when no `ClosestTo` constraint is present and raises
    ``ValueError`` when more than one is supplied. Constraints of any other
    type are ignored.
    """
    matches = tuple(filter(lambda item: isinstance(item, ClosestTo), constraints))

    # Ordering by proximity to several targets at once is not supported.
    if len(matches) > 1:
        raise ValueError(f"Got multiple ClosestTo constraints: {matches}")

    if not matches:
        return None

    (only_match,) = matches
    return only_match.node_id


class QueryableENRDB(ENRDatabaseAPI):
"""
An implementation of :class:`eth_enr.abc.QueryableENRDatabaseAPI` on top of
Expand Down Expand Up @@ -154,7 +171,7 @@ def query(self, *constraints: ConstraintAPI) -> Iterable[ENRAPI]:
"""
Query the database for records that match the given constraints.
Support constrants:
Support constraints:
- :class:`~eth_enr.constraints.KeyExists`
- :class:`~eth_enr.constraints.HasTCPIPv4Endpoint`
Expand All @@ -166,5 +183,12 @@ def query(self, *constraints: ConstraintAPI) -> Iterable[ENRAPI]:
with the highest sequence number for each node_id.
"""
required_keys = _get_required_keys(*constraints)
for record in query_records(self.connection, required_keys=required_keys):
order_closest_to = _get_order_closest_to(*constraints)

records_iter = query_records(
self.connection,
required_keys=required_keys,
order_closest_to=order_closest_to,
)
for record in records_iter:
yield record.to_enr()
43 changes: 34 additions & 9 deletions eth_enr/sqlite3_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import operator
import sqlite3
from typing import Collection, Iterable, NamedTuple, Tuple, Union
from typing import Collection, Iterable, NamedTuple, Optional, Sequence, Tuple, Union

from eth_typing import NodeID
import rlp
Expand All @@ -16,6 +16,7 @@

RECORD_CREATE_STATEMENT = """CREATE TABLE record (
node_id BLOB NOT NULL,
short_node_id INTEGER NOT NULL,
sequence_number INTEGER NOT NULL,
signature BLOB NOT NULL,
created_at DATETIME NOT NULL,
Expand Down Expand Up @@ -201,16 +202,18 @@ def from_row(
fields=tuple(sorted(fields, key=operator.attrgetter("key"))),
)

def to_database_params(self) -> Tuple[NodeID, int, bytes, str]:
def to_database_params(self) -> Tuple[NodeID, int, int, bytes, str]:
# Build the positional values for one `record` row, in the same order as the
# column list of RECORD_INSERT_QUERY:
# (node_id, short_node_id, sequence_number, signature, created_at).
return (
self.node_id,
# short_node_id: the node_id read as a big-endian integer and shifted right
# by 193 bits, used for doing proximate queries.
# NOTE(review): for a 32-byte (256-bit) node_id this keeps the high 63 bits,
# not 64 -- presumably so the value stays non-negative in SQLite's signed
# 64-bit INTEGER column; confirm node_id is always 32 bytes.
int.from_bytes(self.node_id, "big") >> 193,
self.sequence_number,
self.signature,
# created_at is serialized as ISO-8601 text with a space between date and time.
self.created_at.isoformat(sep=" "),
)


RECORD_INSERT_QUERY = "INSERT INTO record (node_id, sequence_number, signature, created_at) VALUES (?, ?, ?, ?)" # noqa: E501
RECORD_INSERT_QUERY = "INSERT INTO record (node_id, short_node_id, sequence_number, signature, created_at) VALUES (?, ?, ?, ?, ?)" # noqa: E501

FIELD_INSERT_QUERY = (
'INSERT INTO field (node_id, sequence_number, "key", value) VALUES (?, ?, ?, ?)'
Expand Down Expand Up @@ -297,6 +300,12 @@ def delete_record(conn: sqlite3.Connection, node_id: NodeID) -> int:
record.sequence_number = field.sequence_number
{where_statements}
GROUP BY record.node_id
{order_by_statement}
"""


# ORDER BY fragment that sorts rows by the XOR distance between a bound short
# node id and `record.short_node_id`. SQLite has no bitwise XOR operator, so
# the distance is spelled as (a | b) - (a & b), which equals a XOR b.
# `{PARAM_IDX}` is filled in with str.format to produce a numbered SQLite
# parameter (`?N`), so both occurrences refer to the same bound value.
PROXIMATE_ORDER_BY_CLAUSE = """
ORDER BY ((?{PARAM_IDX} | record.short_node_id) - (?{PARAM_IDX} & record.short_node_id))
"""


Expand All @@ -311,22 +320,38 @@ def delete_record(conn: sqlite3.Connection, node_id: NodeID) -> int:


def query_records(
conn: sqlite3.Connection, required_keys: Collection[bytes] = ()
conn: sqlite3.Connection,
required_keys: Sequence[bytes] = (),
order_closest_to: Optional[NodeID] = None,
) -> Iterable[Record]:
num_required_keys = len(required_keys)

if num_required_keys == 0:
query = BASE_QUERY.format(where_statements="")
where_clause = ""
elif num_required_keys == 1:
query = BASE_QUERY.format(where_statements=f"WHERE {EXISTS_CLAUSE}")
where_clause = f"WHERE {EXISTS_CLAUSE}"
else:
query_components = tuple([f"({EXISTS_CLAUSE})"] * num_required_keys)
combined_query_components = " AND ".join(query_components)
query = BASE_QUERY.format(where_statements=f"WHERE {combined_query_components}")
where_clause = f"WHERE {combined_query_components}"

if order_closest_to is None:
order_by_clause = ""
params = tuple(required_keys)
else:
order_by_clause = PROXIMATE_ORDER_BY_CLAUSE.format(
PARAM_IDX=num_required_keys + 1
)
short_node_id = int.from_bytes(order_closest_to, "big") >> 193
params = tuple(required_keys) + (short_node_id,)

query = BASE_QUERY.format(
where_statements=where_clause, order_by_statement=order_by_clause
)

logger.debug("query_records: query=%s params=%r", query, required_keys)
logger.debug("query_records: query=%s params=%r", query, params)

for record_row in conn.execute(query, required_keys):
for record_row in conn.execute(query, params):
node_id, sequence_number, *_ = record_row
field_rows = conn.execute(FIELD_GET_QUERY, (node_id, sequence_number))

Expand Down
30 changes: 30 additions & 0 deletions tests/core/test_enr_db_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
UDP_PORT_ENR_KEY,
)
from eth_enr.constraints import (
ClosestTo,
KeyExists,
has_tcp_ipv4_endpoint,
has_tcp_ipv6_endpoint,
Expand Down Expand Up @@ -196,3 +197,32 @@ def test_query_for_ipv6_endpoint(enr_db, constraint):
enr = enr_results[0]

assert enr == enr_b


def test_query_with_order_by_closest(enr_db):
    """
    Records returned for a `ClosestTo` query come back ordered by XOR distance
    to the target node id, both alone and combined with other constraints.
    """
    # The first generated ENR is the lookup target and is never stored.
    target, *candidates = (ENRFactory() for _ in range(4))
    target_as_int = int.from_bytes(target.node_id, "big")

    expected = tuple(
        sorted(
            candidates,
            key=lambda enr: target_as_int ^ int.from_bytes(enr.node_id, "big"),
        )
    )

    for candidate in candidates:
        enr_db.set_enr(candidate)

    # Proximity ordering alone, then combined with a key-existence constraint
    # and with an endpoint constraint.
    extra_constraint_sets = (
        (),
        (KeyExists(b"ip"),),
        (has_tcp_ipv4_endpoint,),
    )
    for extra_constraints in extra_constraint_sets:
        found = tuple(enr_db.query(ClosestTo(target.node_id), *extra_constraints))
        assert found == expected

0 comments on commit d9b7506

Please sign in to comment.