Skip to content

Commit

Permalink
HNSW bug fixes (#230)
Browse files Browse the repository at this point in the history
Fix bug with HNSW.copy().
  • Loading branch information
ekzhu committed Oct 2, 2023
1 parent c98c145 commit 1ce3f69
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 31 deletions.
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,9 @@ benchmark/**/*.pdf

# Virtual env
.venv

# IDE
.vscode

# MacOS
.DS_Store
43 changes: 14 additions & 29 deletions datasketch/hnsw.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations
from collections import OrderedDict
import heapq
from itertools import dropwhile
from typing import (
Hashable,
Callable,
Expand Down Expand Up @@ -59,7 +58,7 @@ def __iter__(self) -> Iterable[Hashable]:
def copy(self) -> _Layer:
"""Create a copy of the layer."""
new_layer = _Layer(None)
new_layer._graph = {k: v.copy() for k, v in self._graph.items()}
new_layer._graph = {k: dict(v) for k, v in self._graph.items()}
return new_layer

def get_reverse_edges(self, key: Hashable) -> Set[Hashable]:
Expand Down Expand Up @@ -91,6 +90,8 @@ def __setitem__(self, key: Hashable, value: Dict[Hashable, float]) -> None:
self._reverse_edges[neighbor].discard(key)
for neighbor in value:
self._reverse_edges.setdefault(neighbor, set()).add(key)
if key not in self._reverse_edges:
self._reverse_edges[key] = set()

def __delitem__(self, key: Hashable) -> None:
old_neighbors = self._graph.get(key, {})
Expand All @@ -115,8 +116,8 @@ def __iter__(self) -> Iterable[Hashable]:
def copy(self) -> _LayerWithReversedEdges:
"""Create a copy of the layer."""
new_layer = _LayerWithReversedEdges(None)
new_layer._graph = {k: v.copy() for k, v in self._graph.items()}
new_layer._reverse_edges = self._reverse_edges.copy()
new_layer._graph = {k: dict(v) for k, v in self._graph.items()}
new_layer._reverse_edges = {k: set(v) for k, v in self._reverse_edges.items()}
return new_layer

def get_reverse_edges(self, key: Hashable) -> Set[Hashable]:
Expand Down Expand Up @@ -169,6 +170,9 @@ class HNSW(MutableMapping):
the 0th level. If None, defaults to 2 * m.
seed (Optional[int]): The random seed to use for the random number
generator.
reverse_edges (bool): Whether to maintain reverse edges in the graph.
This speeds up hard remove (:meth:`remove`) but increases memory
usage and slows down :meth:`insert`.
Examples:
Expand Down Expand Up @@ -400,7 +404,9 @@ def copy(self) -> HNSW:
ef_construction=self._ef_construction,
m0=self._m0,
)
new_index._nodes = self._nodes.copy()
new_index._nodes = OrderedDict(
(key, node.copy()) for key, node in self._nodes.items()
)
new_index._graphs = [layer.copy() for layer in self._graphs]
new_index._entry_point = self._entry_point
new_index._random.set_state(self._random.get_state())
Expand Down Expand Up @@ -608,6 +614,7 @@ def _repair_connections(
entry_point,
entry_point_dist,
layer,
# We allow soft-deleted points to be returned and used as entry point.
allow_soft_deleted=True,
key_to_hard_delete=key_to_delete,
)
Expand All @@ -620,6 +627,8 @@ def _repair_connections(
entry_points,
layer,
ef + 1, # We add 1 to ef to account for the point itself.
# We allow soft-deleted points to be returned and used as entry point
# and neighbor candidates.
allow_soft_deleted=True,
key_to_hard_delete=key_to_delete,
)
Expand Down Expand Up @@ -1045,27 +1054,3 @@ def merge(self, other: HNSW) -> HNSW:
new_index = self.copy()
new_index.update(other)
return new_index

def get_non_reachable_keys(self, ef: Optional[int] = None) -> List[Hashable]:
"""Return a list of keys of points that are not reachable from the entry
point using the given ``ef`` value.
Args:
ef (Optional[int]): The number of neighbors to consider during
search. If None, use the construction ef.
Returns:
List[Hashable]: A list of keys of points that are not reachable.
"""
if ef is None:
ef = self._ef_construction
non_reachable = []
if self._entry_point is None:
return non_reachable
for key, node in self._nodes.items():
if node.is_deleted:
continue
neighbors = self.query(node.point, ef=ef)
if key not in [k for k, _ in neighbors]:
non_reachable.append(key)
return non_reachable
2 changes: 1 addition & 1 deletion datasketch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.6.3"
__version__ = "1.6.4"
5 changes: 4 additions & 1 deletion test/test_hnsw.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ def test_copy(self):
hnsw2 = hnsw.copy()
self.assertEqual(hnsw, hnsw2)

hnsw.remove(0)
self.assertTrue(0 not in hnsw)
self.assertTrue(0 in hnsw2)

def test_soft_remove_and_pop_and_clean(self):
data = self._create_random_points()
hnsw = self._create_index(data)
Expand Down Expand Up @@ -162,7 +166,6 @@ def test_soft_remove_and_pop_and_clean(self):
"Potential graph connectivity issue."
)
# NOTE: we are not getting the expected number of results.
# This may be because the graph is not connected anymore.
# Try hard remove all previous soft removed points.
hnsw.clean()
results = hnsw.query(data[i], 10)
Expand Down

0 comments on commit 1ce3f69

Please sign in to comment.