Rewritten NNDescent algorithm using matrices instead of dictionaries.

TODO: Remove for-loops where possible.

In [154]:
import torch
import time
#import matplotlib.pyplot as plt
#from collections import defaultdict

class NNDescent:
  def __init__(self, data, k=2, trees=10, init=2):
    # The data is a (N x d) matrix with N instances of d-dimensional points
    self.data = data
    N = data.shape[0]
    self.k = k
    self.numtrees = trees
    
    # A 2D tensor representing a directed graph.
    # The value a = graph[i,j] represents an edge from point x_i to x_a.
    self.graph = torch.zeros(size=[N, k], dtype=torch.long)
    
    # Initialize graph randomly or with forest
    if init == 1:
      self.initialize_graph_randomly()
    elif init == 2:
      self.initialize_graph_big_random(data)
    elif init == 3:
      self.initialize_graph_forest(data)

    # A set of tuples (i,j) of indices for which the distance has already been calculated.
    self.explored_edges = set()

    # A 2D tensor representing the distance between point x_i and x_graph[i,j]
    self.k_distances = torch.zeros([N, k])
    self.calculate_all_distances()
    

  def initialize_graph_randomly(self):
    '''
    Initializes self.graph with random values such that each point has k distinct neighbors
    '''
    N, k = self.graph.shape
    # Initialize graph randomly, removing self-loops
    self.graph = torch.randint(high = N-1, size=[N,k], dtype=torch.long)
    row_indices = torch.arange(N).unsqueeze(1).repeat(1,k)
    self.graph[self.graph>=row_indices] += 1

  def initialize_graph_big_random(self, data):
    '''
    Initializes self.graph randomly, but with more neighbours at the start
    '''
    N, k = self.graph.shape
    temp_graph = torch.tensor([])
        
    # make 'trees', combine into giant graph with each element (row) having k * num_trees neighbours
    # this is a small for loop - numtrees and k << datapoints
    for j in range(self.numtrees):
      tree_graph = torch.tensor([])
      for i in range(k):
        tree_graph = torch.cat((tree_graph,torch.randperm(N)),0) # generate randomly shuffled list of N indices
      tree_graph = tree_graph.reshape(-1,k) # creates a N x k tensor with N indices, each appearing k times. This represents 1 'tree'
      temp_graph = torch.cat((temp_graph,tree_graph),1) # combine into giant N x (k*num_trees) tensor. This represents the forest
    
    # find KNN for each row in giant graph
    # TODO - implement the below without a for loop
    for i, row in enumerate(temp_graph):
      temp_row = torch.unique(row).type(torch.LongTensor) # remove duplicates
      temp_row = temp_row[temp_row != i] # remove self
      
      temp_points = data[temp_row,:] # pick out elements from dataset
      distances = dist_bulk(temp_points,data[i]) # Euclidean distances
      indices = distances.topk(k=self.k, largest=False).indices # find indices of KNN
      self.graph[i] = temp_row[indices] # assign KNN to graph
      
  def initialize_graph_forest(self, data):
    '''
    Initializes self.graph with a forest of random trees, such that each point has k distinct neighbors
    '''
    N, k = self.graph.shape
    
    temp_graph = torch.tensor(())
    for j in range(self.numtrees):
      # Create trees, obtain leaves
      t = tree(data, k = k)
      
      # Create temporary graph, 1 for each tree
      # Leaves are of uneven size; select smallest leaf size as graph size
      cols = min([len(leaf) for leaf in t.leaves])
      rows = len(t.leaves)
      tree_graph = torch.zeros((N, cols))
      leaves = torch.tensor(())
      idx_update = torch.tensor(())
      
      # Update graph using leaves
      for leaf in t.leaves:
        temp_idx = torch.as_strided(torch.tensor(leaf).repeat(1,2),size=[len(leaf),cols],stride=[1,1],storage_offset=1)
        tree_graph[leaf,:] = temp_idx.float() # update graph. a lot of overwriting
#         idx_update = torch.cat((idx_update,temp_idx), 0)
#         leaves = torch.cat((leaves,torch.tensor(leaf)), 0)
#       leaves = leaves.unique().long()
#       tree_graph[leaves,:] = idx_update.float()
      # Concatenate all graphs from all trees into 1 giant graph
      temp_graph = torch.cat((temp_graph,tree_graph), 1)
    
    warning_count = 0 # number of indices for which some neighbours are random
    # find KNN for each row in giant graph
    # TODO - implement the below without a for loop
    for i, row in enumerate(temp_graph):
      temp_row = torch.unique(row).type(torch.LongTensor) # remove duplicates
      temp_row = temp_row[temp_row != i] # remove self
      
      temp_points = data[temp_row,:] # pick out elements from dataset
      d=((data[i].reshape(1,4).unsqueeze(1)-temp_points.unsqueeze(0))**2).sum(-1)
      distances, indices = torch.sort(d,dim=1)
      indices = indices.flatten()[:k]
      
      indices = temp_row[indices]
      
      # pad with random indices if there are not enough neighbours
      warning = False # warning flag
      while len(indices) < k:
        pad = torch.randint(high = N-1, size=[k-len(indices),], dtype=torch.long)
        indices = torch.cat((indices,pad))
        indices = torch.unique(indices).type(torch.LongTensor) # remove duplicates
        indices = indices[indices != i] # remove self
        warning = True

      self.graph[i] = indices # assign KNN to graph
      
      if warning:
        warning_count += 1
    
    if warning_count:
      print("WARNING!",warning_count," INDICES ARE RANDOM!")
        
  def calculate_all_distances(self):
    '''
    Updates the distances (self.k_distances) of the edges found in self.graph.
    '''
    # Note: Start with for loop for simplicity. TODO: Try to remove loop.
    for i, row in enumerate(self.graph):
      # Indices of current k neighbors in self.graph
      neighbor_indices = [(i,int(r)) for r in row]

      # The distances of those neighbors are saved in k_distances
      self.k_distances[i] = torch.Tensor([dist(self.data[a],self.data[b]) for a,b in neighbor_indices])

      # Add pairs to explored_edges set
      self.explored_edges.update(neighbor_indices) 
    

  def update_graph(self, iter=5):
    '''
      Updates the graph using algorithm: https://pynndescent.readthedocs.io/en/latest/how_pynndescent_works.html
    '''
    # [STEP 1: Start with random graph.] Iterate
    start = time.time()
    for it in range(iter):
      print("Iteration number",it,"with average distance of",torch.mean(self.k_distances).item(),"Took", time.time()-start,"seconds.")
      has_changed = False

      # [STEP 2: For each node:] (TODO: Investigate whether this can be vectorized.)
      for i, neighbors in enumerate(self.graph):
        # Distances of current neighbors
        dist_current_neighbors = self.k_distances[i]

        # [STEP 3: Measure distance from the node to the neighbors of its neighbors]
        # Find neighbors of neighbors
        potential_neighbors = {a.item() for a in self.graph[neighbors].flatten() \
                               if a not in neighbors and a!=i and (i,int(a)) not in self.explored_edges}
        potential_distances = torch.Tensor([dist(self.data[i],self.data[n]) for n in potential_neighbors])
        self.explored_edges.update([(i,int(r)) for r in potential_neighbors])

        # Concatenate potential neighbors to list of neighbors (indices and distances)
        cat_idx = torch.cat([neighbors, torch.Tensor(list(potential_neighbors))])
        cat_dist = torch.cat([self.k_distances[i], potential_distances])

        # [STEP 4: If any are closer, then update the graph accordingly, and only keep the k closest]
        # Sort using torch.sort(), which also returns sorted indices
        dist_sorted, idx = torch.sort(cat_dist)
        if torch.max(idx[:self.k]) >= self.k:
          has_changed = True
          self.graph[i] = cat_idx[idx[:self.k]]
          self.k_distances[i] = dist_sorted[:self.k]
        
      # [STEP 5: If any changes were made, repeat iteration, otherwise stop]
      if not has_changed:
        print("Nothing changed in iteration",it)
        break
    print("Done.")

  def k_nearest_graph_search(self,x):
    '''
    Gets the k nearest neighbors of input x according to the graph, using this algorithm:
      https://pynndescent.readthedocs.io/en/latest/how_pynndescent_works.html#Searching-using-a-nearest-neighbor-graph
    Input: 
      x - a torch tensor of shape (1,d), where d is the number of dimentions of the input data.
      TODO: Add support for batches, so x can be of shape (n, d). Not sure how to batch-ify the graph search...
    Output:
      The indices of the k nearest neighbors according using graph search with random initialization.
    '''

    # Random initialization for starting point of search
    random_start = torch.randint(len(self.data),size=[1], dtype=torch.long)
    # A list of random initialization and its neighbors
    candidate_idx = torch.cat([self.graph[random_start].squeeze(),random_start], dim=0) 
    # Track the nodes we have explored already
    explored = [random_start]
    
    count = 0
    while True:
      count += 1
      # [2. Look at nodes connected by an edge to the best untried node in graph]
      unexplored = [i for i in candidate_idx if i not in explored]
      if not unexplored:
        # if nothing is unexplored, the search is over
        break 
    
      # Add neighbors of the first unexplored point to the list of candidates
      candidate_idx = torch.cat([self.graph[unexplored[0]],candidate_idx], dim=0) 
      # and mark it as explored
      explored = explored + [unexplored[0]]
    
      # [4. Sort by closeness]
      distances = ((self.data[candidate_idx].unsqueeze(1) - x.unsqueeze(0))**2).sum(-1).squeeze() 
      sorted, idx = torch.sort(distances,dim=0)
      candidate_idx = candidate_idx[idx]
    
      # [5. Truncate to k best]
      # TODO: use Hudson's unwanted_indices method from "Loop replacement.ipynb" to get unique values.
      # It might be faster and is probably more useful when algo supports predictions for batches.
      candidate_idx = torch.unique_consecutive(candidate_idx)[:k] # unique_consecutive doesn't sort
    
      # [6. Return to step 2. If we have already tried all candidates in pool, we stop in the if not unexplored]
    
    # Return the k candidates
    print("Graph search finished after",count,"steps")
    return candidate_idx

  def predict(self,x):
    '''
    Predict output using tree. Hasn't been implemented yet. Needs labels y.
    '''
    pass
    
def dist(x,y):
  # Square of euclidian distance. Skip the root for faster computation.
  return torch.sum((x-y)**2)

def dist_bulk(x,y):
  # Square of euclidian distance. Skip the root for faster computation.
  # For datasets
  return ((x-y)**2).sum(-1)

class tree:
  '''
  Random projection tree class that splits the data evenly per split
  Each split is performed by calculating the projection distance of each datapoint to a random unit vector
  The datapoints are then split by the median of of these projection distances
  The indices of the datapoints are stored in tree.leaves, as a nested list
  '''
  def __init__(self, x, k=5):
    self.min_size = k
    self.leaves = []
    self.sizes = []
    indices = torch.arange(x.shape[0])
    self.tree = self.make_tree(x, indices)

  def make_tree(self, x, indices):
    if x.shape[0] > self.min_size:
      v = self.choose_rule(x)
      distances = torch.tensordot(x,v,dims=1) # create list of projection distances
      median = torch.median(distances)
      left_bool = distances <= median # create boolean array where entries are true if distance <= median
      right_bool = ~left_bool # inverse of left_bool
      left_indices = indices[left_bool]
      right_indices = indices[right_bool]
      self.make_tree(x[left_bool,:],left_indices)
      self.make_tree(x[right_bool,:],right_indices)
    elif x.shape[0] != 0:
      self.leaves.append(indices.tolist())
      self.sizes.append(x.shape[0])
    return

  def choose_rule(self, x):
    dim = x.shape[1]
    v = torch.rand(dim) # create random vector
    v /= torch.norm(v) # normalize to unit vector
    return v
    
def check_accuracy(data, graph, k_distances):
  '''
  Takes in data and graph to check accuracy of graph's assigned k nearest neighbours
  Uses torch brute force to find actual k nearest neighbours
  Returns accuracy: proportion of correct nearest neighbours
  Also returns distance error: (average_distance-true_distances)/true_distance (of k nearest neighbours)
  '''
  N, k = graph.shape
  graph = torch.sort(graph,dim=1)[0] # sort each row of graph

  # Calculate true distances, indices
  d=((data.unsqueeze(1)-data.unsqueeze(0))**2).sum(-1)+torch.Tensor([float('inf')]).repeat(len(data)).diag() # Infinity is added to diagonal
  true_distances, true_indices = torch.sort(d,dim=1)

  # get k nearest neighbours
  true_indices = true_indices[:,:k]
  true_distances = true_distances[:,:k]
  
  # Calculate number of correct nearest neighbours
  accuracy = 0
  for i in range(k):
    accuracy += torch.sum(graph == true_indices).float()
    torch.roll(true_indices, 1, -1) # Create a rolling window (index positions may not match)
  accuracy = float(accuracy/(N*k)) # percentage accuracy

  # Calculate accuracy of distances
  true_average = torch.mean(true_distances)
  graph_average = torch.mean(k_distances)
  distance_error = float((graph_average-true_average)/true_average)

  return accuracy, distance_error

In [178]:
# Testing out NNDescent class
# torch.manual_seed(1)
data = torch.Tensor([[1.0,1.0], [2.0,1.0], [3.0,1.0], [4.0,1.0],
                     [1.0,2.0], [2.0,2.0], [3.0,2.0], [4.0,2.0]])  
data = torch.randn(size=[10000,4])
# print(data)  

torch.set_printoptions(threshold=10)

k = 3

# Initialize NNDescent graph randomly

print("Initializing randomly...")
start = time.time()
n = NNDescent(data, k=k, init=1)
print("Took", time.time()-start,"seconds.\n")
print("Graph:")
print(n.graph)
print("Distances:")
print(torch.sqrt(n.k_distances))

print("Updating...\n")
start = time.time()
n.update_graph(iter=25)
print("Took", time.time()-start,"seconds.\n")
print("Graph:")
print(n.graph)
print("Distances:")
print(torch.sqrt(n.k_distances))
#print(n.k_distances)
accuracy, distance_error = check_accuracy(data, n.graph, n.k_distances)
print("Accuracy: ",accuracy)
print("Distance Error: ",distance_error,'\n')



# Initialize NNDescent graph with big random
print("Initializing big random...")
start = time.time()
n = NNDescent(data, k=k, init=2)
print("Took", time.time()-start,"seconds.\n")
print("Graph:")
print(n.graph)
print("Distances:")
print(torch.sqrt(n.k_distances))

print("Updating...\n")
start = time.time()
n.update_graph(iter=25)
print("Took", time.time()-start,"seconds.\n")
print("Graph:")
print(n.graph)
print("Distances:")
print(torch.sqrt(n.k_distances))
#print(n.k_distances)
accuracy, distance_error = check_accuracy(data, n.graph, n.k_distances)
print("Accuracy: ",accuracy)
print("Distance Error: ",distance_error,'\n')



# Initialize NNDescent graph with forest
print("Initializing forest...")
start = time.time()
n = NNDescent(data, k=k, init=3)
print("Took", time.time()-start,"seconds.\n")
print("Graph:")
print(n.graph)
print("Distances:")
print(torch.sqrt(n.k_distances))

print("Updating...\n")
start = time.time()
n.update_graph(iter=25)
print("Took", time.time()-start,"seconds.\n")
print("Graph:")
print(n.graph)
print("Distances:")
print(torch.sqrt(n.k_distances))
#print(n.k_distances)
accuracy, distance_error = check_accuracy(data, n.graph, n.k_distances)
print("Accuracy: ",accuracy)
print("Distance Error: ",distance_error,'\n')



# Brute force search
print("Brute Force")
start = time.time()
m=((data.unsqueeze(1)-data.unsqueeze(0))**2).sum(-1)+torch.Tensor([float('inf')]).repeat(len(data)).diag() # Infinity is added to diagonal
distances, brute_force = torch.topk(m,k=k,dim=1,largest=False)
print("Took", time.time()-start,"seconds.\n")
print(brute_force)
print('mean distance',(distances).sqrt().mean())
accuracy, distance_error = check_accuracy(data, brute_force, distances)
print("Accuracy: ",accuracy)
print("Distance Error: ",distance_error,'\n')

# Get k nearest neighbors using graph search
print()
print("Graph search")
#x = torch.Tensor([4, 0]).unsqueeze(0) # size 1 x d
x = torch.randn(size=[1,data.shape[1]])
print("x is",x)
start = time.time()
k_nearest = n.k_nearest_graph_search(x)
print("Took", time.time()-start,"seconds.")
print("Nearest indices with graph search:",k_nearest)
print("coordinates of nearest dots:",n.data[k_nearest])

# Get k nearest using brute force knn
print("\nActual nearest using KNN")
start = time.time()
m=((data.unsqueeze(1)-x.unsqueeze(0))**2).sum(-1).squeeze() # Infinity is added to diagonal
print(m)
brute_force = torch.sort(m,dim=0)[1][:k]
print("Took", time.time()-start,"seconds.")
print("The KNN nearest are:",brute_force)
print("coordinates of nearest dots:",n.data[brute_force])

Initializing randomly...
Took 0.9426238536834717 seconds.

Graph:
tensor([[2313, 2724,  163],
        [ 814, 3297, 4559],
        [3724, 4563, 4981],
        ...,
        [3983, 3163, 1796],
        [ 972, 7119, 1581],
        [3873, 3228, 4218]])
Distances:
tensor([[2.6757, 3.7233, 2.6305],
        [3.0326, 3.8609, 3.7230],
        [2.9173, 3.6125, 3.7517],
        ...,
        [1.9602, 2.4340, 3.0428],
        [3.3533, 1.5052, 3.4768],
        [2.3655, 3.5801, 3.8535]])
Updating...

Iteration number 0 with average distance of 8.052292823791504 Took 0.0 seconds.
Iteration number 1 with average distance of 3.1890859603881836 Took 5.2776336669921875 seconds.
Iteration number 2 with average distance of 2.110304832458496 Took 11.32966947555542 seconds.
Iteration number 3 with average distance of 1.693931221961975 Took 16.122722864151 seconds.
Iteration number 4 with average distance of 1.518083930015564 Took 20.697295665740967 seconds.
Iteration number 5 with average distance of 1.4415931

In [120]:
# Sandbox for testing
i = 0
neighbors = n.graph[i]
print("Neighbors:",neighbors)
print("Neighbors of neighbors:",n.graph[neighbors].flatten())
potential_neighbors = {a.item() for a in n.graph[neighbors].flatten() if a not in neighbors and a!=i}
print("New potential neighbors:", potential_neighbors)
potential_distances = torch.Tensor([dist(data[i],data[n]) for n in potential_neighbors])
print("Potential distances:", (potential_distances))
cat_idx = torch.cat([neighbors, torch.Tensor(list(potential_neighbors))])
cat_dist = torch.cat([n.k_distances[i], potential_distances])
print("cat_idx:", cat_idx)
print("cat_dist:", cat_dist)
print("sort cat:", torch.sort(cat_dist))
val, idx = torch.sort(cat_dist)
print("idx max", torch.max(idx)>3)
print()
print("New neighbors:",cat_idx[idx[:3]])
print("New distances:",val[:3])

Neighbors: tensor([27, 37])
Neighbors of neighbors: tensor([ 0, 37, 46,  0])
New potential neighbors: {46}
Potential distances: tensor([1.6462])
cat_idx: tensor([27., 37., 46.])
cat_dist: tensor([0.6104, 1.0715, 1.6462])
sort cat: torch.return_types.sort(
values=tensor([0.6104, 1.0715, 1.6462]),
indices=tensor([0, 1, 2]))
idx max tensor(False)

New neighbors: tensor([27., 37., 46.])
New distances: tensor([0.6104, 1.0715, 1.6462])
