Rewritten NNDescent algorithm using matrices instead of dictionaries.

TODO: Remove for-loops where possible.

In [2]:
!pip install pykeops[colab] > install.log

In [17]:
import torch
import time
import matplotlib.pyplot  as plt
from pykeops.torch import LazyTensor
#from collections import defaultdict

# Select the compute device once; the classes and helpers below read this
# module-level `device` when they move tensors.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
if use_cuda:
  # Wait for any pending CUDA work before the notebook starts timing things.
  torch.cuda.synchronize()

class NNDescent:
  def __init__(self, data, k=2, trees=5, init=2, leaf_multiplier=1, LT=False):
    # The data is a (N x d) matrix with N instances of d-dimensional points
    self.data = data.to(device)
    N = data.shape[0]
    self.k = k
    self.numtrees = trees
    self.big_leaves = None
    
    # A 2D tensor representing a directed graph.
    # The value a = graph[i,j] represents an edge from point x_i to x_a.
    self.graph = torch.zeros(size=[N, k], dtype=torch.long)
    
    # Initialize graph randomly or with forest
    if init == 1:
      self.initialize_graph_randomly()
    elif init == 2:
      self.initialize_graph_big_random(data)
    elif init == 3:
      self.initialize_graph_forest(data, leaf_multiplier)
    elif init == 4:
      self.initialize_graph_clusters(data, leaf_multiplier, LT)
    
#     accuracy, _ = init_accuracy(data,self.graph,torch.zeros([N, k]))
#     print('Initial Accuracy: ',accuracy)
    
    # A set of tuples (i,j) of indices for which the distance has already been calculated.
    self.explored_edges = set()

    # A 2D tensor representing the distance between point x_i and x_graph[i,j]
    self.k_distances = torch.zeros([N, k])
    self.calculate_all_distances()
    

  def initialize_graph_randomly(self):
    '''
    Initializes self.graph with random values such that each point has k distinct neighbors
    '''
    N, k = self.graph.shape
    # Initialize graph randomly, removing self-loops
    self.graph = torch.randint(high = N-1, size=[N,k], dtype=torch.long)
    row_indices = torch.arange(N).unsqueeze(1).repeat(1,k)
    self.graph[self.graph>=row_indices] += 1

  def initialize_graph_big_random(self, data):
    '''
    Initializes self.graph randomly, but with more neighbours at the start
    '''
    N, k = self.graph.shape
    temp_graph = torch.tensor([])
        
    # make 'trees', combine into giant graph with each element (row) having k * num_trees neighbours
    # this is a small for loop - numtrees and k << datapoints
    for j in range(self.numtrees):
      tree_graph = torch.tensor([])
      for i in range(k):
        tree_graph = torch.cat((tree_graph,torch.randperm(N)),0) # generate randomly shuffled list of N indices
      tree_graph = tree_graph.reshape(-1,k) # creates a N x k tensor with N indices, each appearing k times. This represents 1 'tree'
      temp_graph = torch.cat((temp_graph,tree_graph),1) # combine into giant N x (k*num_trees) tensor. This represents the forest
    
    # find KNN for each row in giant graph
    # TODO - implement the below without a for loop
    for i, row in enumerate(temp_graph):
      temp_row = torch.unique(row).type(torch.LongTensor) # remove duplicates
      temp_row = temp_row[temp_row != i] # remove self
      
      temp_points = data[temp_row,:] # pick out elements from dataset
      distances = dist_bulk(temp_points,data[i]) # Euclidean distances
      indices = distances.topk(k=self.k, largest=False).indices # find indices of KNN
      self.graph[i] = temp_row[indices] # assign KNN to graph
      
  def initialize_graph_forest(self, data, leaf_multiplier):
    '''
    Initializes self.graph with a forest of random trees, such that each point has k distinct neighbors.

    For every tree, each point receives other members of its own leaf as candidates.
    The candidates of all trees are merged, and the k closest ones (never the point
    itself) are written to self.graph. Points with fewer than k candidates are padded
    with random indices, and a warning with the number of such points is printed.

    Input:
      data - a torch tensor of shape (N, d) holding the points the graph indexes.
      leaf_multiplier - trees are built with k*leaf_multiplier as their `k` argument.
    Side effects:
      Sets self.big_leaves from the first tree's big_leaves.
    '''
    N, k = self.graph.shape
    dim = data.shape[1]
    
    # Candidate table of the whole forest; grows by `cols` columns per tree.
    temp_graph = torch.tensor(())
    for j in range(self.numtrees):
      # Create trees, obtain leaves
      t = tree(data, k = k*leaf_multiplier)
      
      # Create temporary graph, 1 for each tree
      # Leaves are of uneven size; select smallest leaf size as graph size
      cols = min([len(leaf) for leaf in t.leaves])
      # NOTE: rows, leaves and idx_update are assigned but never read in this method.
      rows = len(t.leaves)
      # Candidate indices are stored as floats here and converted back to long further down.
      tree_graph = torch.zeros((N, cols))
      leaves = torch.tensor(())
      idx_update = torch.tensor(())
      
      # Update graph using leaves
      for leaf in t.leaves:
        # The leaf is written twice back to back and viewed as overlapping windows:
        # row r holds the `cols` entries that follow position r, wrapping around the
        # leaf. When cols == len(leaf) the last entry of a row is the point itself;
        # it is filtered out below.
        temp_idx = torch.as_strided(torch.tensor(leaf).repeat(1,2),size=[len(leaf),cols],stride=[1,1],storage_offset=1)
        # Row r of temp_idx becomes the candidate list of data point leaf[r].
        tree_graph[leaf,:] = temp_idx.float() # update graph
      # Concatenate all graphs from all trees into 1 giant graph
      temp_graph = torch.cat((temp_graph,tree_graph), 1)

      # Add the first tree's big_leaves to the NNDescent's big_leaves
      if j==0:
          self.big_leaves = torch.LongTensor(t.big_leaves)
    
    warning_count = 0 # number of indices for which some neighbours are random
    # find KNN for each row in giant graph
    # TODO - implement the below without a for loop
    for i, row in enumerate(temp_graph):
      temp_row = torch.unique(row).type(torch.LongTensor) # remove duplicates
      temp_row = temp_row[temp_row != i] # remove self
      
      temp_points = data[temp_row,:] # pick out elements from dataset
      # Squared Euclidean distances from point i to its candidates, shape (1, num_candidates)
      d=((data[i].reshape(1,dim).unsqueeze(1)-temp_points.unsqueeze(0))**2).sum(-1)
      distances, indices = torch.sort(d,dim=1)
      # Positions (within temp_row) of the k closest candidates; fewer if there are not enough
      indices = indices.flatten()[:k]
      
      # Translate positions back to data indices
      indices = temp_row[indices]
      
      # pad with random indices if there are not enough neighbours
      # NOTE: the pad is drawn from 0..N-2, so index N-1 is never used as padding, and
      # torch.unique returns the indices sorted, so padded rows are no longer ordered
      # by distance.
      # NOTE(review): presumably loops forever when fewer than k distinct indices other
      # than i exist in that range - confirm callers always have N well above k.
      warning = False # warning flag
      while len(indices) < k:
        pad = torch.randint(high = N-1, size=[k-len(indices),], dtype=torch.long)
        indices = torch.cat((indices,pad))
        indices = torch.unique(indices).type(torch.LongTensor) # remove duplicates
        indices = indices[indices != i] # remove self
        warning = True

      self.graph[i] = indices # assign KNN to graph
      
      if warning:
        warning_count += 1
    
    if warning_count:
      print("WARNING!",warning_count," INDICES ARE RANDOM!")

      
  def initialize_graph_clusters(self, data, leaf_multiplier, LT=False):
    '''
    Builds the graph over clusters instead of over single points: one random
    projection tree partitions the data, every leaf becomes a cluster, and
    self.graph links every cluster centroid to its k nearest centroids.

    Sets self.clusters (cluster id per point, -1 if unassigned), self.centroids,
    self.centroids_neighbours, self.graph (the same tensor as centroids_neighbours)
    and self.big_leaves (the tree's big_leaves translated to cluster ids).
    '''
    num_points, _ = data.shape
    self.clusters = torch.full((num_points,), -1.0)

    # Create the tree on the compute device and read its leaves
    forest = tree(data.to(device), self.k, leaf_multiplier, LT)

    # The position of a leaf in the tree's leaf list is the cluster id of its points
    for cluster_id, members in enumerate(forest.leaves):
      self.clusters[members] = cluster_id
    self.centroids = forest.centroids.clone()

    # k+1 nearest centroids of every centroid; column 0 is dropped
    # (presumably the centroid itself, at distance 0)
    pairwise = ((LazyTensor(self.centroids.unsqueeze(1)) - LazyTensor(self.centroids.unsqueeze(0)))**2).sum(-1)
    nearest = pairwise.argKmin(K=self.k+1, dim=1).long()
    self.graph = self.centroids_neighbours = nearest[:, 1:]

    # big_leaves holds data indices in the tree; look up the cluster of each one
    representatives = torch.LongTensor(forest.big_leaves)
    self.big_leaves = self.clusters[representatives].long()
      
  def calculate_all_distances(self):
    '''
    Updates the distances (self.k_distances) of the edges found in self.graph.
    '''
    # Note: Start with for loop for simplicity. TODO: Try to remove loop.
    for i, row in enumerate(self.graph):
      # Indices of current k neighbors in self.graph
      neighbor_indices = [(i,int(r)) for r in row]

      # The distances of those neighbors are saved in k_distances
      self.k_distances[i] = torch.Tensor([dist(self.data[a],self.data[b]) for a,b in neighbor_indices])

      # Add pairs to explored_edges set
      self.explored_edges.update(neighbor_indices) 
    

  def update_graph(self, iter=5):
    '''
      Updates the graph using algorithm: https://pynndescent.readthedocs.io/en/latest/how_pynndescent_works.html
    '''
    # [STEP 1: Start with random graph.] Iterate
    start = time.time()
    for it in range(iter):
#       print("Iteration number",it,"with average distance of",torch.mean(self.k_distances).item(),"Took", time.time()-start,"seconds.")
      has_changed = False

      # [STEP 2: For each node:] (TODO: Investigate whether this can be vectorized.)
      for i, neighbors in enumerate(self.graph):
        # Distances of current neighbors
        dist_current_neighbors = self.k_distances[i]

        # [STEP 3: Measure distance from the node to the neighbors of its neighbors]
        # Find neighbors of neighbors
        potential_neighbors = {a.item() for a in self.graph[neighbors].flatten() \
                               if a not in neighbors and a!=i and (i,int(a)) not in self.explored_edges}
        potential_distances = torch.Tensor([dist(self.data[i],self.data[n]) for n in potential_neighbors])
        self.explored_edges.update([(i,int(r)) for r in potential_neighbors])

        # Concatenate potential neighbors to list of neighbors (indices and distances)
        cat_idx = torch.cat([neighbors, torch.Tensor(list(potential_neighbors))])
        cat_dist = torch.cat([self.k_distances[i], potential_distances])

        # [STEP 4: If any are closer, then update the graph accordingly, and only keep the k closest]
        # Sort using torch.sort(), which also returns sorted indices
        dist_sorted, idx = torch.sort(cat_dist)
        if torch.max(idx[:self.k]) >= self.k:
          has_changed = True
          self.graph[i] = cat_idx[idx[:self.k]]
          self.k_distances[i] = dist_sorted[:self.k]
        
      # [STEP 5: If any changes were made, repeat iteration, otherwise stop]
#       if not has_changed:
#         print("Nothing changed in iteration",it)
#         break
#     print("Done.")

  def k_nearest_graph_search(self,X,max_num_steps = 100, tree_init = True):
    '''
    Gets the k nearest neighbors of input x according to the graph, using this algorithm:
      https://pynndescent.readthedocs.io/en/latest/how_pynndescent_works.html#Searching-using-a-nearest-neighbor-graph
    Input: 
      x - a torch tensor of shape (N,d), where N is the number of dimentions of the input data.
    Output:
      The indices of the k nearest neighbors according using graph search with random initialization.
    '''
    # N datapoints of dimension d
    N, d = X.shape
    X = X.to(device)
    
    k = self.k
    # Boolean mask to keep track of those points whose search is still ongoing
    is_active = (torch.ones(N)==1).to(device)
    
    # If graph was initialized using trees, we can use information from there to initialize in a diversed manner.
    if self.big_leaves is not None and tree_init:
        candidate_idx = self.big_leaves.unsqueeze(0).repeat(N,1).to(device) # Shape: (N,32)
    else:
        # Random initialization for starting points of search. 
        candidate_idx = torch.randint(high=self.graph.shape[0], size=[N, k+1], dtype=torch.long).to(device)
  
    # Sort the candidates by distance from X
    distances = ((self.centroids[candidate_idx] - X.unsqueeze(1))**2).sum(-1)
    sorted, idx = torch.sort(distances, dim=1)
    candidate_idx = torch.gather(candidate_idx, dim=1,index=idx)
    # Truncate to k+1 nearest
    candidate_idx = candidate_idx[:,:(k+1)]
    
    # Track the nodes we have explored already, in N x num_explored tensor
    num_explored = self.k*2
    explored = torch.full(size=[N,num_explored],fill_value=-1).to(device)

    start = time.time()
    # The initialization of candidates and explored set is done. Now we can search.
    count = 0
    while count < max_num_steps:
        print("Step",count,"- Search is completed for",1-torch.mean(1.0*is_active).item(),"- this step took",time.time()-start,"s")
        start=time.time()
    
        # [2. Look at nodes connected by an edge to the best untried node in graph]
        # diff_bool.shape is (M, k+1, num_explored), where M is the number of active searches
        diff_bool = (candidate_idx[is_active].unsqueeze(2) - explored[is_active].unsqueeze(1) == 0)
        in_explored = torch.any(diff_bool, dim=2)
        # batch_active is true for those who haven't been fully explored in the current batch
        batch_active = ~torch.all(in_explored[:,:-1], dim=1)

        # Update is_active mask. If none are active, break search
        is_active[is_active.clone()] = batch_active
        if not is_active.any():
            break

        # first_unexplored has indices of first unexplored element per row
        first_unexplored = torch.max(~in_explored[batch_active],dim=1)[1].unsqueeze(1) 
        # Unexplored nodes to be expanded
        unexplored_idx = torch.gather(candidate_idx[is_active], dim=1, index=first_unexplored).squeeze(-1)
        explored[is_active,(count%num_explored)] = unexplored_idx
    
        # [3. Add all these nodes to our potential candidate pool]
        # Add neighbors of the first unexplored point to the list of candidates
        expanded_idx = torch.cat((self.graph[unexplored_idx],candidate_idx[is_active]), dim=1) 
    
        # To avoid repeats, we use Hudson's unwanted_indices method to find repeats
        expanded_idx = torch.sort(expanded_idx)[0]
        shift = torch.cat((torch.full((len(expanded_idx),1),-1).to(device),torch.sort(expanded_idx,dim=1)[0][:,:-1]),dim=1)
        unwanted_indices = (expanded_idx==shift)
    
        # [4. Sort by closeness].
        distances = ((self.centroids[expanded_idx] - X[is_active].unsqueeze(1))**2).sum(-1)
        distances[unwanted_indices] += float('inf')
        sorted, idx = torch.sort(distances,dim=1)
        expanded_idx = torch.gather(expanded_idx,dim=1,index=idx)
    
        # [5. Truncate to k+1 best]
        candidate_idx[is_active] = expanded_idx[:,:(self.k+1)] 
        
        # [6. Return to step 2. If we have already tried all candidates in pool, we stop in the if not unexplored]
        count += 1
    
    # Return the k candidates
    print("Graph search finished after",count,"steps. Finished for:",1-torch.mean(1.0*is_active).item())
    return candidate_idx[:,:-1]

#   def k_nearest_graph_search(self,X,max_num_steps = 100, tree_init = True):
#     '''
#     Gets the k nearest neighbors of input x according to the graph, using this algorithm:
#       https://pynndescent.readthedocs.io/en/latest/how_pynndescent_works.html#Searching-using-a-nearest-neighbor-graph
#     Input: 
#       x - a torch tensor of shape (N,d), where N is the number of dimentions of the input data.
#     Output:
#       The indices of the k nearest neighbors according using graph search with random initialization.
#     '''

#     # N datapoints of dimension d
#     N, d = X.shape
#     k = self.k
#     # Boolean mask to keep track of those points whose search is still ongoing
#     is_active = (torch.ones(N)==1)
    
#     # If graph was initialized using trees, we can use information from there to initialize in a diversed manner.
#     if self.big_leaves is not None and tree_init:
#         candidate_idx = self.big_leaves.unsqueeze(0).repeat(N,1) # Shape: (N,32)
#     else:
#         # Random initialization for starting points of search. 
#         candidate_idx = torch.randint(high=len(self.data),size=[N, k+1], dtype=torch.long)
    
#     # Sort the candidates by distance from X
#     distances = ((self.data[candidate_idx] - X.unsqueeze(1))**2).sum(-1)
#     sorted, idx = torch.sort(distances, dim=1)
#     candidate_idx = torch.gather(candidate_idx, dim=1,index=idx)
#     # Truncate to k+1 nearest
#     candidate_idx = candidate_idx[:,:(k+1)]
    
#     # Track the nodes we have explored already, in N x num_explored tensor
#     num_explored = self.k*2
#     explored = torch.full(size=[N,num_explored],fill_value=-1)

#     start = time.time()
#     # The initialization of candidates and explored set is done. Now we can search.
#     count = 0
#     while count < max_num_steps:
#         print("Step",count,"- Search is completed for",1-torch.mean(1.0*is_active).item(),"- this step took",time.time()-start,"s")
#         start=time.time()
    
#         # [2. Look at nodes connected by an edge to the best untried node in graph]
#         # diff_bool.shape is (M, k+1, num_explored), where M is the number of active searches
#         diff_bool = (candidate_idx[is_active].unsqueeze(2) - explored[is_active].unsqueeze(1) == 0)
#         in_explored = torch.any(diff_bool, dim=2)
#         # batch_active is true for those who haven't been fully explored in the current batch
#         batch_active = ~torch.all(in_explored[:,:-1], dim=1)

#         # Update is_active mask. If none are active, break search
#         is_active[is_active] = batch_active
#         if not is_active.any():
#             break

#         # first_unexplored has indices of first unexplored element per row
#         first_unexplored = torch.max(~in_explored[batch_active],dim=1)[1].unsqueeze(1) 
#         # Unexplored nodes to be expanded
#         unexplored_idx = torch.gather(candidate_idx[is_active], dim=1, index=first_unexplored).squeeze(-1)
#         explored[is_active,(count%num_explored)] = unexplored_idx
    
#         # [3. Add all these nodes to our potential candidate pool]
#         # Add neighbors of the first unexplored point to the list of candidates
#         expanded_idx = torch.cat((self.graph[unexplored_idx],candidate_idx[is_active]), dim=1) 
    
#         # To avoid repeats, we use Hudson's unwanted_indices method to find repeats
#         expanded_idx = torch.sort(expanded_idx)[0]
#         shift = torch.cat((torch.full((len(expanded_idx),1),-1),torch.sort(expanded_idx,dim=1)[0][:,:-1]),dim=1)
#         unwanted_indices = (expanded_idx==shift)
    
#         # [4. Sort by closeness].
#         distances = ((self.data[expanded_idx] - X[is_active].unsqueeze(1))**2).sum(-1)
#         distances[unwanted_indices] += float('inf')
#         sorted, idx = torch.sort(distances,dim=1)
#         expanded_idx = torch.gather(expanded_idx,dim=1,index=idx)
    
#         # [5. Truncate to k+1 best]
#         candidate_idx[is_active] = expanded_idx[:,:(self.k+1)] 
        
#         # [6. Return to step 2. If we have already tried all candidates in pool, we stop in the if not unexplored]
#         count += 1
    
#     # Return the k candidates
#     print("Graph search finished after",count,"steps. Finished for:",1-torch.mean(1.0*is_active).item())
#     return candidate_idx[:,:-1]

  def predict(self,x):
    '''
    Predict output using tree. Hasn't been implemented yet. Needs labels y.
    '''
    pass
    
    
def dist(x,y):
  # Squared Euclidean distance between two tensors, summed over all elements.
  # The square root is skipped on purpose: it does not change the ordering and saves time.
  difference = x - y
  return torch.sum(difference * difference)

def dist_bulk(x,y):
  # Squared Euclidean distance along the last dimension (no square root, for speed).
  # Meant for whole datasets: x and y are broadcast against each other, so a
  # (n, d) tensor against a single (d,) point gives n distances.
  return (x - y).pow(2).sum(dim=-1)

def dot_product(x, v):
  # Dot product of every row of matrix x with vector v, evaluated with KeOps LazyTensors.
  # Returns a flat tensor with one value per row of x.
  vector_LT = LazyTensor(v.view(1,1,-1))
  rows_LT = LazyTensor(x.unsqueeze(0))
  products = (vector_LT|rows_LT).sum_reduction(axis=0)
  return products.flatten()

class tree: # NN clusters tree
  '''
  Random projection tree class that splits the data evenly per split
  Each split is performed by calculating the projection distance of each datapoint to a random unit vector
  The datapoints are then split by the median of these projection distances
  The indices of the datapoints are stored in tree.leaves, as a nested list

  Attributes available after construction:
    leaves     - list of leaves, each a list of data indices (at most min_size long)
    sizes      - number of points of every leaf, in the same order as leaves
    centroids  - (num_leaves, d) tensor; row j is the mean of the points of leaves[j]
    big_leaves - one data index (the first one) of every non-empty node at depth 5
  '''
  def __init__(self, x, k=5, leaf_multiplier=1, LT=False):
    '''
    Input:
      x - a torch tensor of shape (N, d) with the points to partition.
      k, leaf_multiplier - a node holding more than k*leaf_multiplier points is split further.
      LT - if True the projections are computed with dot_product (LazyTensors),
           otherwise with torch.tensordot.
    '''
    self.min_size = k * leaf_multiplier
    self.leaves = []
    self.sizes = []
    self.big_leaves = [] # leaves at depth = 5
    indices = torch.arange(x.shape[0])

    self.dim = x.shape[1]
    self.data = x.to(device)
    self.LT = LT # Boolean to choose LT or torch initialization

    # Leaf centroids are collected in a list while the tree is built and stacked
    # once at the end, instead of re-concatenating a growing tensor per leaf.
    self._centroid_rows = []
    self.tree = self.make_tree(indices, depth = 0)
    if self._centroid_rows:
      self.centroids = torch.stack(self._centroid_rows)
    else:
      self.centroids = torch.tensor(()).to(device).reshape(-1,x.shape[1])
    del self._centroid_rows

  def make_tree(self, indices, depth):
    '''
    Recursively splits the points given by `indices` and records the leaves.
    Returns None; the results are stored on the instance.
    '''
    size = indices.shape[0]
    if size == 0:
      return # an empty half has no first index and no centroid
    if depth == 5: # add to big_leaves if depth=5
      self.big_leaves.append(int(indices[0]))

    # A single point can never be split, whatever min_size says.
    if size > max(self.min_size, 1):
      v = self.choose_rule().to(device)

      if self.LT:
        distances = dot_product(self.data[indices],v) # create list of projection distances
      else:
        distances = torch.tensordot(self.data[indices],v,dims=1) # create list of projection distances

      median = torch.median(distances)
      # boolean array where entries are true if distance <= median
      left_bool = (distances <= median).to(indices.device)
      if bool(left_bool.all()):
        # Every projection equals the median (e.g. duplicated points), so the
        # median separates nothing. Split by position instead, to guarantee
        # that both halves are smaller and the recursion terminates.
        left_bool = torch.zeros_like(left_bool)
        left_bool[:(size + 1) // 2] = True
      self.make_tree(indices[left_bool], depth+1)
      self.make_tree(indices[~left_bool], depth+1)
    else:
      self.leaves.append(indices.tolist())
      self.sizes.append(size)
      # Centroid of this leaf only: the mean of the leaf's own points.
      self._centroid_rows.append(self.data[indices].mean(dim=0))
    return

  def choose_rule(self):
    '''
    Returns a random direction of unit length to project the points on.
    '''
    # NOTE: torch.rand draws from [0, 1), so all components are non-negative.
    v = torch.rand(self.dim) # create random vector
    v /= torch.norm(v) # normalize to unit vector
    return v
    
def init_accuracy(data, graph, k_distances):
  '''
  Measures how good the graph's k nearest neighbours are, against a torch brute force search.
  Input:
    data - (N, d) tensor of points
    graph - (N, k) tensor with the neighbour indices to check
    k_distances - tensor with the distances that belong to the graph's edges
  Returns accuracy: proportion of graph entries that are among the true k nearest neighbours
  Also returns distance error: (average_distance-true_distances)/true_distance (of k nearest neighbours)
  '''
  num_points, num_neighbors = graph.shape

  # Brute force pairwise squared distances; infinity on the diagonal keeps a point
  # from being its own neighbour
  pairwise = ((data.unsqueeze(1)-data.unsqueeze(0))**2).sum(-1)
  pairwise = pairwise + torch.diag(torch.full((len(data),), float('inf')))
  sorted_distances, sorted_indices = torch.sort(pairwise, dim=1)

  # get k nearest neighbours
  true_indices = sorted_indices[:, :num_neighbors]
  true_distances = sorted_distances[:, :num_neighbors]

  # Count the correct nearest neighbours. Index positions may not match, so the
  # true indices are compared under every cyclic shift of their columns.
  hits = sum(torch.sum(graph == torch.roll(true_indices, shift, -1)).float()
             for shift in range(num_neighbors))
  accuracy = float(hits/(num_points*num_neighbors)) # percentage accuracy

  # Relative error of the mean distance
  true_average = torch.mean(true_distances)
  distance_error = float((torch.mean(k_distances)-true_average)/true_average)

  return accuracy, distance_error

def knn_accuracy_x_y(indices_test, x, y):
  '''
  Compares the test indices with the ground truth found by brute force
  Returns accuracy: proportion of correct nearest neighbours

  indices_test: k nearest neighbour indices (rows = KNN for each query point)
  x: training points that the model is fitted to
  y: query points for which we want to find the nearest KNN
  '''
  num_queries, num_neighbors = indices_test.shape

  # Ground truth: the k points of x closest to each query, moved to the compute device
  squared_distances = ((y.unsqueeze(1)-x.unsqueeze(0))**2).sum(-1)
  indices_truth = torch.argsort(squared_distances, dim=1)[:, :num_neighbors].to(device)

  # Count the correct nearest neighbours. Index positions may not match, so the
  # truth is compared under every cyclic shift of its columns.
  accuracy = sum(torch.sum(indices_test == torch.roll(indices_truth, shift, -1)).float()/num_queries
                 for shift in range(num_neighbors))
  accuracy = float(accuracy/num_neighbors) # percentage accuracy

  return accuracy

In [None]:
# Testing out NNDescent class
# Builds the graph with three initializations (random, big random, forest), refines each
# one with update_graph and reports accuracy against a brute force search. The last
# part runs the graph search for new query points.
# torch.manual_seed(1)
# NOTE: this 8-point toy grid is immediately replaced by the random dataset assigned
# on the next statement, so it is never used.
data = torch.Tensor([[1.0,1.0], [2.0,1.0], [3.0,1.0], [4.0,1.0],
                     [1.0,2.0], [2.0,2.0], [3.0,2.0], [4.0,2.0]])  
data = torch.randn(size=[1000,4])
# print(data)  

# Abbreviate printed tensors with more than 10 elements
torch.set_printoptions(threshold=10)

k = 5

# Initialize NNDescent graph randomly

print("Initializing RANDOMLY...")
start = time.time()
n = NNDescent(data, k=k, init=1)
print("Took", time.time()-start,"seconds.\n")
print("Graph:")
print(n.graph)
# print("Distances:")
# k_distances holds squared distances, hence the square root for display
print(torch.sqrt(n.k_distances))

print("Updating...\n")
start = time.time()
n.update_graph(iter=25)
print("Took", time.time()-start,"seconds.\n")
print("Graph:")
print(n.graph)
print("Distances:")
print(torch.sqrt(n.k_distances))
#print(n.k_distances)
accuracy, distance_error = init_accuracy(data, n.graph, n.k_distances)
print("Accuracy: ",accuracy)
print("Distance Error: ",distance_error,'\n')



# Initialize NNDescent graph with big random
print("Initializing BIG RANDOM...")
start = time.time()
n = NNDescent(data, k=k, init=2)
print("Took", time.time()-start,"seconds.\n")
print("Graph:")
print(n.graph)
print("Distances:")
print(torch.sqrt(n.k_distances))

print("Updating...\n")
start = time.time()
n.update_graph(iter=25)
print("Took", time.time()-start,"seconds.\n")
print("Graph:")
print(n.graph)
print("Distances:")
print(torch.sqrt(n.k_distances))
#print(n.k_distances)
accuracy, distance_error = init_accuracy(data, n.graph, n.k_distances)
print("Accuracy: ",accuracy)
print("Distance Error: ",distance_error,'\n')



# Initialize NNDescent graph with forest
print("Initializing FOREST...")
start = time.time()
n = NNDescent(data, k=k, init=3)
print("Took", time.time()-start,"seconds.\n")
print("Graph:")
print(n.graph)
print("Distances:")
print(torch.sqrt(n.k_distances))

print("Updating...\n")
start = time.time()
n.update_graph(iter=25)
print("Took", time.time()-start,"seconds.\n")
print("Graph:")
print(n.graph)
print("Distances:")
print(torch.sqrt(n.k_distances))
#print(n.k_distances)
accuracy, distance_error = init_accuracy(data, n.graph, n.k_distances)
print("Accuracy: ",accuracy)
print("Distance Error: ",distance_error,'\n')



# Brute force search
print("BRUTE FORCE")
start = time.time()
# Full pairwise squared distances; the diagonal (distance of a point to itself) stays 0 here
m=((data.unsqueeze(1)-data.unsqueeze(0))**2).sum(-1)#+torch.Tensor([float('inf')]).repeat(len(data)).diag() # Infinity is added to diagonal
distances, brute_force = torch.sort(m,dim=1)
# Skip column 0 of the sorted result: presumably the point itself at distance 0
# (assumes the dataset has no duplicated points)
brute_force = brute_force[:,1:k+1]
distances = distances[:,1:k+1]
  
print("Took", time.time()-start,"seconds.\n")
print(brute_force)
print('mean distance',(distances).sqrt().mean())
accuracy, distance_error = init_accuracy(data, brute_force, distances)
print("Accuracy: ",accuracy)
print("Distance Error: ",distance_error,'\n')

# Get k nearest neighbors using graph search
# NOTE: n is still the forest-initialized (init=3) instance created above
print()
print("Graph search")
#x = torch.Tensor([4, 0]).unsqueeze(0) # size 1 x d
X = torch.randn(size=[1000,data.shape[1]])
print("X is",X)
start = time.time()
k_nearest = n.k_nearest_graph_search(X)
print("Took", time.time()-start,"seconds.")
print("Nearest indices with graph search:\n",k_nearest)

print("K is",k)
# Get k nearest using brute force knn
print("\nActual nearest using KNN")
print("Brute force search:") # Crashes for 1 million points +
start = time.time()
# (num_queries, num_points) squared distances between every query in X and every data point
m=((data.unsqueeze(0)-X.unsqueeze(1))**2).sum(-1)
distances, brute_force = torch.topk(m,k=k,dim=1,largest=False)
print("Took", time.time()-start,"seconds.")
print("Nearest indices with brute force:\n",brute_force)


Initializing RANDOMLY...
Took 0.2392749786376953 seconds.

Graph:
tensor([[265,  71, 881, 660, 148],
        [226, 608, 513, 318, 311],
        [822,  88, 574, 882, 382],
        ...,
        [546, 931, 302, 135, 636],
        [296,  58, 910, 340, 186],
        [428, 503, 830, 639, 241]])
tensor([[2.1963, 2.7346, 3.5557, 2.4839, 3.8999],
        [1.1433, 0.8892, 2.7923, 3.1788, 1.6471],
        [3.5750, 4.0008, 3.1970, 3.1660, 2.2671],
        ...,
        [1.5027, 0.5909, 2.5328, 2.3769, 2.4326],
        [2.9869, 2.5517, 1.8289, 1.1147, 1.4828],
        [2.0158, 1.3755, 2.8946, 3.5899, 2.8191]])
Updating...

Took 22.09316849708557 seconds.

Graph:
tensor([[ 50, 149, 320, 378, 205],
        [478, 447, 118, 532, 293],
        [940, 819, 748, 635, 504],
        ...,
        [736, 275, 717, 485, 434],
        [988, 790, 634,  31, 411],
        [463, 810, 310, 384, 751]])
Distances:
tensor([[0.6784, 0.6786, 0.7200, 0.7867, 0.8495],
        [0.5603, 0.7698, 0.7900, 0.8069, 0.8143],
        

In [18]:
# Comparing the graph search accuracy of tree initialization
x = torch.randn(size=[1000,4])
n = NNDescent(x, k=6, init=4)
# n.update_graph() # no need for this anymore

k = 3 # This k can be different than the graph's k.

y = torch.randn(size=[10000,x.shape[1]])
#print("X is",X)
print("Search without tree init:")
start = time.time()
k_nearest_random = n.k_nearest_graph_search(y, tree_init = False, max_num_steps=9)
print("Took", time.time()-start,"seconds.\n")

print("Search with tree init:")
start = time.time()
k_nearest_forest = n.k_nearest_graph_search(y, tree_init = True, max_num_steps=9)
print("Took", time.time()-start,"seconds.")
print("Nearest indices with graph search:\n",k_nearest_forest)

m=((x.unsqueeze(0)-y.unsqueeze(1))**2).sum(-1)
distances, brute_force = torch.topk(m,k=k,dim=1,largest=False)
brute_force = brute_force.to(device)

print("Comparing (naively) the accuracy")
print("Random:",(1.0*(brute_force==k_nearest_random[:,:k])).mean())
print("Forest:",(1.0*(brute_force==k_nearest_forest[:,:k])).mean())

print("Accuracy Checkers")
print('Random:', knn_accuracy_x_y(k_nearest_random,x,y))
print('Forest:', knn_accuracy_x_y(k_nearest_forest,x,y))

Search without tree init:
Step 0 - Search is completed for 0.0 - this step took 0.00013256072998046875 s
Step 1 - Search is completed for 0.0 - this step took 0.0029261112213134766 s
Step 2 - Search is completed for 0.0 - this step took 0.0038721561431884766 s
Step 3 - Search is completed for 0.0 - this step took 0.003962516784667969 s
Step 4 - Search is completed for 0.0 - this step took 0.0044705867767333984 s
Step 5 - Search is completed for 0.0 - this step took 0.004760026931762695 s
Step 6 - Search is completed for 0.0 - this step took 0.00391387939453125 s
Step 7 - Search is completed for 0.003800034523010254 - this step took 0.0024039745330810547 s
Step 8 - Search is completed for 0.0658000111579895 - this step took 0.003998756408691406 s
Graph search finished after 8 steps. Finished for: 1.0
Took 0.04581880569458008 seconds.

Search with tree init:
Step 0 - Search is completed for 0.0 - this step took 8.320808410644531e-05 s
Step 1 - Search is completed for 0.0 - this step took