In [13]:
import sys
sys.path.append('..')  
import numpy as np
from tqdm import tqdm
from scipy.spatial import distance
import pickle
from src.graph_class import Graph

In [3]:
# Readers for the .fvecs (float) and .ivecs (int) vector file formats used by the SIFT datasets
def read_fvecs(fp):
    """Load an ``.fvecs`` file into a 2-D ``float32`` array.

    The file is read as a flat run of 32-bit integers. The first integer is
    taken as the vector dimension ``d``; the data is then arranged as rows of
    ``d + 1`` integers, the leading column is dropped, and what remains is
    reinterpreted bit-for-bit as ``float32``.

    Args:
        fp: File name or open file object accepted by ``numpy.fromfile``.

    Returns:
        A newly allocated ``float32`` array with one row per stored vector
        and ``d`` columns.
    """
    raw_words = np.fromfile(fp, dtype='int32')
    dim = raw_words[0]
    records = raw_words.reshape(-1, dim + 1)
    # Copy the sliced columns so the result is a compact array of its own
    # rather than a strided window onto the raw file buffer.
    payload = records[:, 1:].copy()
    return payload.view('float32')

def read_ivecs(fname):
    """Load an ``.ivecs`` file into a 2-D ``int32`` array.

    The file is read as a flat run of 32-bit integers. The first integer is
    taken as the vector dimension ``d``; the data is then arranged as rows of
    ``d + 1`` integers and the leading column is dropped.

    Args:
        fname: File name or open file object accepted by ``numpy.fromfile``.

    Returns:
        A newly allocated ``int32`` array with one row per stored vector and
        ``d`` columns.
    """
    raw_words = np.fromfile(fname, dtype='int32')
    dim = raw_words[0]
    records = raw_words.reshape(-1, dim + 1)
    # Copy so the caller gets a compact array independent of the raw buffer.
    return records[:, 1:].copy()

In [4]:
def calculate_recall(predicted_neighbors, actual_neighbors):
    """Return the mean per-query recall of predictions against ground truth.

    For each query, recall is the fraction of distinct ground-truth ids that
    also appear among the predicted ids. Duplicates on either side are
    ignored because both are compared as sets.

    Args:
        predicted_neighbors: Sequence with one iterable of predicted ids per
            query.
        actual_neighbors: Sized sequence with one iterable of ground-truth
            ids per query, in the same query order.

    Returns:
        The sum of per-query recalls divided by ``len(actual_neighbors)``.
        Queries paired by ``zip`` only: if there are fewer predictions than
        ground-truth rows, the unmatched rows contribute a recall of 0.
        Returns ``0.0`` when ``actual_neighbors`` is empty.
    """
    query_count = len(actual_neighbors)
    if query_count == 0:
        # Nothing to score against; avoid dividing by zero below.
        return 0.0

    total_recall = 0
    for pred, actual in zip(predicted_neighbors, actual_neighbors):
        relevant = set(actual)
        hits = len(set(pred) & relevant)
        # A query with no ground-truth ids counts as recall 0.
        total_recall += hits / len(relevant) if relevant else 0

    return total_recall / query_count

In [6]:
# Load the siftsmall benchmark files (the small SIFT set, not the 1M-vector one).
# DATA_DIR is the only machine-specific value in this cell: point it at your
# local copy of the dataset before running.
DATA_DIR = 'C:/Users/ewang/OneDrive/Desktop/Fall 2023/cos597a-final-project/data/siftsmall'

# Vectors the graph index is built over and searched through.
base = read_fvecs(f'{DATA_DIR}/siftsmall_base.fvecs')
# Query vectors used to probe the index.
query = read_fvecs(f'{DATA_DIR}/siftsmall_query.fvecs')
# Reference neighbour ids per query, used later to score recall.
groundtruth = read_ivecs(f'{DATA_DIR}/siftsmall_groundtruth.ivecs')

In [16]:
# Build the search graph over the base vectors with the project's Graph class.
# `k` is forwarded to build_nsw_greedy; presumably the number of neighbours
# linked per node — confirm against Graph.build_nsw_greedy.
k = 10
g1 = Graph("sys", data=base)
graph = g1.build_nsw_greedy(base, k)

Building Graph:   0%|          | 0/10000 [00:00<?, ?it/s]

Building Graph:   2%|▏         | 245/10000 [00:01<00:55, 176.89it/s]


KeyboardInterrupt: 

In [104]:

# Persist the built graph to graph.pkl in the working directory so it can be
# reloaded without rebuilding.
with open("graph.pkl", "wb") as f:
    pickle.dump(graph, f)

In [123]:
# Reload the pickled graph from disk.
# NOTE(review): the result is bound to `objects`, but the search cell visible
# in this notebook reads `graph` — confirm which name is intended.
# pickle.load can execute arbitrary code: only load files you created yourself.
with open('graph.pkl', 'rb') as f:
    objects = pickle.load(f)

In [154]:
# Number of query vectors loaded (one row of `query` per query).
len(query)

100

In [193]:
# Run every query through both search strategies and collect, per query, the
# ids of the hits each one returns.
# NOTE(review): greedy_search / beam_search are not defined in the cells
# visible here; they must already be in scope.
k = 5
results_greedy, results_beam = [], []
for q in tqdm(query):
    # Element [0] of each search's return value is the hit list; item 1 of
    # every hit is kept as the neighbour id (presumably hits are
    # (distance, id) pairs — confirm against the search functions).
    g = [r[1] for r in greedy_search(graph, q, k=k)[0]]
    b = [r[1] for r in beam_search(graph, q, k=k)[0]]
    results_greedy.append(g)
    results_beam.append(b)

100%|██████████| 100/100 [00:51<00:00,  1.93it/s]


In [194]:
# Keep only the first k ground-truth ids per query, using the current value
# of `k`, so they are scored against the same number of search results.
# Assumes each ground-truth row is ordered nearest-first — TODO confirm.
true = groundtruth[:, :k]

In [204]:
# Mean recall of the greedy-search results against the truncated ground truth.
average_recall = calculate_recall(results_greedy, true)
print(average_recall)

0.9919999999999999


In [206]:
# Mean recall of the beam-search results against the truncated ground truth.
# Note: this rebinds `average_recall`, replacing any value it held before.
average_recall = calculate_recall(results_beam, true)
print(average_recall)

0.9059999999999995


In [12]:
def test():
    """Time 100 greedy searches on NSW graphs of 1000, 2000 and 4000 nodes.

    Three graphs are created and filled with random 10-dimensional points.
    Each is then queried with 100 random 10-dimensional points, and for every
    graph the node count and the elapsed wall-clock time are printed,
    followed by a blank line.

    NOTE(review): ``NSWGraph`` is neither defined nor imported in the cells
    visible here; it must be in scope before this function is called.
    """
    # Imported here so the function does not depend on a notebook-level
    # import that the visible import cell does not provide.
    from datetime import datetime

    def time_queries(graph):
        """Print the node count and the time taken by 100 random queries."""
        start = datetime.now()
        for _ in range(100):
            # Only the elapsed time matters; the search result is discarded.
            graph.greedy_search(np.random.rand(10))
        end = datetime.now()
        print(len(graph.nodes))
        print(end - start)
        print()

    def add(graph, node_count):
        """Insert ``node_count`` random 10-dimensional points into ``graph``."""
        for _ in range(node_count):
            graph.add_node(np.random.rand(10))

    sizes = (1000, 2000, 4000)

    # Create every graph first, then fill them, then time them, in that order.
    graphs = [NSWGraph() for _ in sizes]
    for graph, size in zip(graphs, sizes):
        add(graph, size)
    for graph in graphs:
        time_queries(graph)

# Run the timing benchmark; it prints node count and elapsed time per graph.
test()

1000
0:00:00.108000

2000
0:00:00.214994

4000
0:00:00.432016

