In [None]:
import numpy as np
import networkx as nx 
from collections import defaultdict
import time

def dist(x,y):
  """Return the sum of squared element-wise differences between x and y.

  No square root is taken, so for coordinate vectors this is the squared
  Euclidean distance.
  """
  delta = x - y
  return np.sum(delta ** 2)

class graph:
  """Adjacency-set graph with a few bookkeeping tables.

  Attributes:
    graph: maps ``str(node)`` to the set of neighbours stored for that node.
      Keys are strings, neighbour values are stored as given.
    ddict: maps a key (callers use ``(node, node)`` tuples) to a float;
      missing keys read as 0.0.
    explored_edges: maps a node to a set of nodes; filled in by callers.
    keys: maps a key to a set of values, filled in through ``add_key``.
  """
  def __init__(self):
    self.graph=defaultdict(set)
    self.ddict=defaultdict(float)
    self.explored_edges=defaultdict(set)
    # add_key() writes to this table; it was never created before, so every
    # add_key() call failed with AttributeError.
    self.keys=defaultdict(set)

  def add_key(self,a,x):
    """Add value ``x`` to the set stored under key ``a``."""
    self.keys[a].add(x)

  def add_edge(self,a,b):
    """Insert the edge a-b in both directions."""
    self.graph[str(a)].add(b)
    self.graph[str(b)].add(a)

  def replace_edge_list(self,a,x):
    """Overwrite the neighbours of ``a`` with the elements of ``x``.

    Only the entry of ``a`` changes; the entries of other nodes are untouched.
    """
    self.graph[str(a)]=set(x)

  def del_edge(self,a,b):
    """Remove the edge a-b in both directions.

    Raises:
      KeyError: if either direction of the edge is not present.
    """
    self.graph[str(a)].remove(b)
    self.graph[str(b)].remove(a)

  def edge_list(self):
    """Return all stored edges as a set of ``(int(node), neighbour)`` tuples."""
    return {
      (int(key),neighbour)
      for key,neighbours in self.graph.items()
      for neighbour in neighbours
    }

  def get_edge(self,a):
    """Return the neighbour set of ``a`` (the live set, not a copy)."""
    return self.graph[str(a)]

  def visualize(self):
    """Draw the graph with networkx and show the resulting figure."""
    # pyplot is imported here because the notebook never imports it at the
    # top, which made the reference to ``plt`` below raise NameError.
    import matplotlib.pyplot as plt
    G = nx.Graph()
    G.add_edges_from(self.edge_list())
    nx.draw_networkx(G)
    plt.show()

#construct NN graph
def construct_graph(x,k=3,count=3,init=3):
  """Build an approximate nearest-neighbour graph over the points in ``x``.

  Each point is first linked to randomly drawn other points. Then, for at
  most ``count`` rounds, every pair of nodes that share a common neighbour
  has its distance recorded, and each node's neighbour list is replaced by
  the ``k`` closest nodes among all those whose distance to it is recorded.

  Args:
    x: indexable collection of points accepted by ``dist``.
    k: number of closest recorded nodes kept per node after each round.
    count: maximum number of refinement rounds.
    init: number of random links drawn per point. Draws may repeat, so a
      point can receive fewer than ``init`` distinct links from its own draws.

  Returns:
    A ``graph`` whose ``ddict`` holds the recorded distances (both
    orientations of each pair) and whose ``explored_edges`` holds, per node,
    the nodes with a recorded distance. With fewer than two points the graph
    is returned without edges.
  """
  from itertools import combinations

  g=graph()
  l=len(x)
  if l<2:
    # There is no j != i to link to; the rejection loop below would never
    # terminate for a single point.
    return g

  def record(a,b):
    # Store the distance for both orientations and mark the pair as explored.
    d=dist(x[a],x[b])
    g.ddict[(a,b)]=d
    g.ddict[(b,a)]=d
    g.explored_edges[a].add(b)
    g.explored_edges[b].add(a)

  for i in range(l):
    for _ in range(init):
      # Rejection-sample a random node different from i.
      while True:
        j=int(l * np.random.random())
        if j!=i:
          break
      g.add_edge(i,j)
      record(i,j)

  #start update here
  for _ in range(count):
    found_new=False

    # Two neighbours of the same node are candidate neighbours of each other.
    # Pairs whose distance is already recorded are skipped: re-evaluating
    # them cannot change any neighbour list. The edge sets are only read
    # here, so recording distances while scanning is safe.
    for index in range(l):
      for a,b in combinations(g.get_edge(index),2):
        if b in g.explored_edges[a]:
          continue
        record(a,b)
        found_new=True

    if not found_new:
      # No new distance was learned, so the neighbour lists would not change.
      break

    #recalculate all neighbours
    for index in range(l):
      ranked=sorted((g.ddict[(other,index)],other) for other in g.explored_edges[index])
      # Keep the k closest; slicing also copes with fewer than k candidates.
      g.replace_edge_list(index,[other for _,other in ranked[:k]])

  return g

def argmin(x,y,g,k=5):
  """Greedy graph search for points of ``x`` close to the query ``y``.

  Starts from one randomly chosen point and repeatedly expands the best
  candidate that has not been expanded yet, adding its graph neighbours to
  the candidate pool and then keeping only the ``k`` candidates with the
  smallest ``dist`` to ``y``.

  Args:
    x: indexable collection of points.
    y: query point.
    g: graph object exposing ``get_edge(node)``.
    k: maximum size of the candidate pool.

  Returns:
    List of candidate indices into ``x``; after the first expansion it is
    ordered by increasing distance to ``y``.
  """
  first=np.random.randint(len(x))
  candidates=[first]
  scores=[dist(y,x[first])]
  expanded=set()

  while True:
    # First candidate in pool order that has not been expanded yet.
    current=next((c for c in candidates if c not in expanded),None)
    if current is None:
      # Every remaining candidate has been expanded: the search has converged.
      return candidates

    for neighbour in g.get_edge(current):
      if neighbour not in candidates:
        candidates.append(neighbour)
        scores.append(dist(y,x[neighbour]))
    expanded.add(current)

    # Sort by (distance, index) and truncate the pool to the k best.
    best=sorted(zip(scores,candidates))[:k]
    scores,candidates=[list(column) for column in zip(*best)]

In [None]:
#testing large dataset
# Seed NumPy's global generator so the sampled points (and every later draw
# from np.random in this notebook) are the same on each Restart & Run All.
SEED=0
np.random.seed(SEED)
size=1000
x = np.random.randn(size, 2)  # reference points
y = np.random.randn(size, 2)  # query points

def get_k_argmin(x,y,k=3):
  """Brute-force reference: indices of the ``k`` closest points of ``x`` per query.

  Args:
    x: array-like of reference points, one per leading index.
    y: array-like of query points, one per leading index.
    k: number of nearest indices to keep for each query.

  Returns:
    Integer array with one row per element of ``y`` holding indices into
    ``x`` ordered by increasing squared distance, passed through
    ``np.squeeze`` (so ``k == 1`` or a single query drops that axis).
  """
  points=np.asarray(x)
  queries=np.asarray(y)
  # One row per query: the table used to be sized by len(x), which only
  # worked when x and y had the same length. Indices are stored as ints.
  nearest=np.zeros([len(queries),k],dtype=int)
  for row,query in enumerate(queries):
    diff=points-query
    # Sum the squared differences over every axis except the point axis.
    d=np.sum(diff**2,axis=tuple(range(1,diff.ndim)))
    nearest[row]=np.argsort(d)[:k]
  return np.squeeze(nearest)

In [None]:
#test nndescent
# time.perf_counter() is the clock intended for measuring elapsed time;
# time.time() follows the wall clock and can jump if that is adjusted.
start=time.perf_counter()
g=construct_graph(x,8,init=8,count=3)
print('construct NN graph time taken:',time.perf_counter()-start)
start=time.perf_counter()
# Keep only the first (closest) candidate returned by each graph search.
nndescent=[argmin(x,query,g,3)[0] for query in y]
print('NN descent search time taken',time.perf_counter()-start)
nndescent=np.squeeze(np.array(nndescent))

construct NN graph time taken: 5.320017099380493
NN descent search time taken 0.8994917869567871


In [None]:
#test argmin and argsort
# Brute-force baselines: every query is compared against every point.
# time.perf_counter() is used because it is the clock intended for
# measuring elapsed time, unlike the adjustable wall clock of time.time().
npargmin=[]
npargsort=[]
start=time.perf_counter()
for query in y:
  d=[dist(point,query) for point in x]
  npargmin.append(np.argmin(d))

print('argmin search time taken',time.perf_counter()-start)
npargmin=np.array(npargmin)
start=time.perf_counter()
for query in y:
  d=[dist(point,query) for point in x]
  # A full sort is done here on purpose, to time argsort against argmin.
  npargsort.append(np.argsort(d)[0])

print('argsort search time taken',time.perf_counter()-start)
npargsort=np.array(npargsort)

argmin search time taken 6.523342609405518
argsort search time taken 6.667661905288696


In [None]:
# Fraction of queries for which the graph search returned exactly the
# brute-force nearest neighbour.
print((npargmin==nndescent).sum()/len(nndescent))

0.85


In [None]:
# Brute-force ground truth: the three closest points of x for every query in y.
k_argmin=get_k_argmin(x,y,k=3)

In [None]:
# Fraction of queries whose predicted neighbour appears in any of the first
# three ground-truth columns (each prediction is compared against its own row).
print((k_argmin[:,:3]==nndescent[:,None]).sum()/len(nndescent))

0.903
