In [10]:
import numpy as np
# from sklearn.metrics.pairwise import euclidean_distances

In [9]:
from datetime import datetime

In [11]:
# NSW
class NSWNode:
    def __init__(self, data):
        self.data = data
        self.neighbors = []

class NSWGraph:
    def __init__(self):
        self.nodes = []

    def add_node(self, data):
        new_node = NSWNode(data)
        for node in self.nodes:
            # Connect new node with existing nodes if they are close
            if self.should_connect(new_node, node):
                new_node.neighbors.append(node)
                node.neighbors.append(new_node)
        self.nodes.append(new_node)

    def should_connect(self, node1, node2, threshold=0.5):
        # This function decides if two nodes should be connected.
        # Threshold can be adjusted based on dataset characteristics.
        distance = np.linalg.norm(node1.data - node2.data)
        return distance < threshold

    def greedy_search(self, query, max_steps=100):
        if not self.nodes:
            return None

        current = np.random.choice(self.nodes)
        for _ in range(max_steps):
            if current is None:
                break
            closest = min(current.neighbors + [current], key=lambda node: np.linalg.norm(node.data - query))
            if closest == current:
                break
            current = closest
        return current

    def beam_search(self, query, beam_width=2, max_steps = 100):
        if not self.nodes:
            return None

        current = np.random.choice(self.nodes)
        for _ in range(max_steps):
            if current is None:
                break
            closest = min(current.neighbors + [current], key=lambda node: np.linalg.norm(node.data - query))
            if closest == current:
                break
            current = closest
        return current

In [15]:
def test():
    def time(graph):
        start = datetime.now()
        for _ in range(100):
            query_point = np.random.rand(10)
            nearest_neighbor = graph.greedy_search(query_point)
        end = datetime.now()
        print(len(graph.nodes))
        print(end - start)
        print()

    def add(graph, node_count):
        for _ in range(node_count):
            graph.add_node(np.random.rand(10))

    nsw1 = NSWGraph()
    nsw2 = NSWGraph()
    nsw3 = NSWGraph()

    add(nsw1, 1000)
    add(nsw2, 2000)
    add(nsw3, 4000)

    time(nsw1)
    time(nsw2)
    time(nsw3)

test()

1000
0:00:00.067083

2000
0:00:00.131408

4000
0:00:00.269464

