Rewritten NNDescent algorithm using matrices instead of dictionaries.

TODO: Remove for-loops where possible.

In [67]:
import torch
import time
#import matplotlib.pyplot as plt
#from collections import defaultdict

class NNDescent:
  '''
  Approximate k-nearest-neighbour graph built with the NN-Descent algorithm.

  The graph is stored as an (N x k) LongTensor rather than dictionaries: row i holds the
  indices of the current k neighbour candidates of data[i], and self.k_distances holds the
  matching squared Euclidean distances (same shape).
  '''
  def __init__(self, data, k=2, trees=10, random=False):
    # The data is a (N x d) matrix with N instances of d-dimensional points
    self.data = data
    N = data.shape[0]
    self.k = k
    self.numtrees = trees

    # A 2D tensor representing a directed graph.
    # The value a = graph[i,j] represents an edge from point x_i to x_a.
    self.graph = torch.zeros(size=[N, k], dtype=torch.long)

    # Initialize graph randomly or with forest
    if random:
      self.initialize_graph_randomly()
    else:
      self.initialize_graph_forest(data)

    # A set of tuples (i,j) of indices for which the distance has already been calculated.
    self.explored_edges = set()

    # A 2D tensor representing the distance between point x_i and x_graph[i,j]
    self.k_distances = torch.zeros([N, k])
    self.calculate_all_distances()

  def initialize_graph_randomly(self):
    '''
    Initializes self.graph with random values such that each point has k distinct
    neighbors, none of which is the point itself.

    Raises:
      ValueError: if k > N-1 (there are not enough other points to choose from).
    '''
    N, k = self.graph.shape
    if k > N - 1:
      raise ValueError(f"Cannot pick {k} distinct neighbors among {N - 1} other points")

    # Floyd's sampling algorithm, vectorized over all rows: draws a k-subset of
    # {0, ..., N-2} per row with exactly k steps and no rejection loop.
    picks = torch.empty(size=[N, k], dtype=torch.long)
    for step, upper in enumerate(range(N - 1 - k, N - 1)):
      draw = torch.randint(high=upper + 1, size=[N], dtype=torch.long)
      # If the draw was already taken in this row, take `upper` instead; it is
      # larger than every earlier pick, so it is guaranteed to be new.
      already_taken = (picks[:, :step] == draw.unsqueeze(1)).any(dim=1)
      draw[already_taken] = upper
      picks[:, step] = draw

    # Map {0..N-2} onto {0..N-1} minus the row index, which removes self-loops
    # while keeping the picks distinct.
    row_indices = torch.arange(N).unsqueeze(1)
    picks[picks >= row_indices] += 1
    self.graph = picks

  def initialize_graph_forest(self, data):
    '''
    Initializes self.graph from a forest of random candidate lists. For every point it
    gathers candidates from self.numtrees random "trees", then keeps the k closest.

    Raises:
      ValueError: if self.numtrees < 1 (there would be no candidates at all).
    '''
    N, k = self.graph.shape
    if self.numtrees < 1:
      raise ValueError(f"initialize_graph_forest needs at least one tree (got trees={self.numtrees})")

    # Each 'tree' is k random permutations of 0..N-1 laid out row-major into an (N x k)
    # block, so every index appears exactly k times in the block. The blocks are kept as
    # LongTensors so indices stay exact (float32 cannot represent all integers above 2**24).
    # This is a small loop: numtrees and k << number of datapoints.
    trees = [torch.cat([torch.randperm(N) for _ in range(k)]).reshape(N, k)
             for _ in range(self.numtrees)]
    # Combine into one N x (k*numtrees) candidate matrix: the forest.
    forest = torch.cat(trees, dim=1)

    # Find the k nearest candidates for each row of the forest
    # TODO - implement the below without a for loop
    for i, row in enumerate(forest):
      candidates = torch.unique(row)            # remove duplicates
      candidates = candidates[candidates != i]  # remove self
      distances = dist_bulk(data[candidates], data[i])  # squared Euclidean distances
      nearest = distances.topk(k=self.k, largest=False).indices
      self.graph[i] = candidates[nearest]

  def calculate_all_distances(self):
    '''
    Recomputes self.k_distances for every edge in self.graph and records all of those
    edges in self.explored_edges.
    '''
    # (N x k x d) neighbour coordinates against (N x 1 x d) source points -> (N x k).
    # Assigning into the existing tensor keeps its dtype (float32 as created in __init__).
    self.k_distances[:] = dist_bulk(self.data[self.graph], self.data.unsqueeze(1))

    for i, row in enumerate(self.graph.tolist()):
      self.explored_edges.update((i, j) for j in row)

  def update_graph(self, iter=5):
    '''
    Updates the graph using algorithm: https://pynndescent.readthedocs.io/en/latest/how_pynndescent_works.html

    In each iteration, every node is compared against the neighbours of its neighbours,
    skipping pairs already in self.explored_edges. It keeps the k closest. The method stops
    early once an iteration changes nothing.

    Args:
      iter: maximum number of iterations (the name shadows the builtin; kept for compatibility).
    '''
    # [STEP 1: Start with random graph.] Iterate
    start = time.time()
    for it in range(iter):
      print("Iteration number",it,"with average distance of",torch.mean(self.k_distances).item(),"Took", time.time()-start,"seconds.")
      has_changed = False

      # [STEP 2: For each node:] (TODO: Investigate whether this can be vectorized.)
      for i in range(self.graph.shape[0]):
        neighbors = self.graph[i]
        current = set(neighbors.tolist())

        # [STEP 3: Measure distance from the node to the neighbors of its neighbors]
        # Plain Python ints keep the membership tests cheap and the indices exact.
        potential = [a for a in set(self.graph[neighbors].flatten().tolist())
                     if a not in current and a != i and (i, a) not in self.explored_edges]
        potential_idx = torch.tensor(potential, dtype=torch.long)
        potential_dist = dist_bulk(self.data[potential_idx], self.data[i]).to(self.k_distances.dtype)
        self.explored_edges.update((i, a) for a in potential)

        # Concatenate potential neighbors to list of neighbors (indices and distances)
        cat_idx = torch.cat([neighbors, potential_idx])
        cat_dist = torch.cat([self.k_distances[i], potential_dist])

        # [STEP 4: If any are closer, then update the graph accordingly, and only keep the k closest]
        dist_sorted, order = torch.sort(cat_dist)
        best = order[:self.k]
        # Positions >= k in cat_idx are new candidates, so the neighbour list changed.
        if (best >= self.k).any():
          has_changed = True
          self.graph[i] = cat_idx[best]
          self.k_distances[i] = dist_sorted[:self.k]

      # [STEP 5: If any changes were made, repeat iteration, otherwise stop]
      if not has_changed:
        print("Nothing changed in iteration",it)
        break
    print("Done.")

  def k_nearest_graph_search(self,x):
    '''
    Gets the k nearest neighbors of input x according to the graph, using this algorithm:
      https://pynndescent.readthedocs.io/en/latest/how_pynndescent_works.html#Searching-using-a-nearest-neighbor-graph
    Input:
      x - a torch tensor of shape (1,d) or (d,), where d is the number of dimensions of the input data.
      TODO: Add support for batches, so x can be of shape (n, d). Not sure how to batch-ify the graph search...
    Output:
      A LongTensor with the indices of (at most) self.k nearest neighbors found by graph
      search from a random starting node, ordered from closest to farthest.
    '''
    query = x.reshape(-1)

    # Random initialization for starting point of search
    seed_node = int(torch.randint(len(self.data), size=[1], dtype=torch.long))
    # The starting node and its neighbors form the initial candidate pool
    candidate_idx = torch.cat([self.graph[seed_node], torch.tensor([seed_node])])
    # Nodes whose neighbor lists have already been expanded (plain ints for O(1) lookup)
    explored = {seed_node}

    count = 0
    while True:
      count += 1
      # [2. Look at nodes connected by an edge to the best untried node in graph]
      unexplored = [c for c in candidate_idx.tolist() if c not in explored]
      if not unexplored:
        # if nothing is unexplored, the search is over
        break
      best_untried = unexplored[0]

      # Add neighbors of the best untried point to the candidates and mark it as explored
      candidate_idx = torch.cat([self.graph[best_untried], candidate_idx])
      explored.add(best_untried)

      # [4. Sort by closeness]
      distances = ((self.data[candidate_idx] - query) ** 2).sum(-1)
      _, order = torch.sort(distances, dim=0)

      # [5. Truncate to k best]
      # Copies of the same index share a distance, so after sorting they are adjacent
      # (barring exact ties with a different point) and unique_consecutive drops them.
      # TODO: use Hudson's unwanted_indices method from "Loop replacement.ipynb" to get unique values.
      # It might be faster and is probably more useful when algo supports predictions for batches.
      candidate_idx = torch.unique_consecutive(candidate_idx[order])[:self.k]

      # [6. Return to step 2. If we have already tried all candidates in pool, we stop in the if not unexplored]

    # Return the k candidates
    print("Graph search finished after",count,"steps")
    return candidate_idx

  def predict(self,x):
    '''
    Predict output using tree. Hasn't been implemented yet. Needs labels y.
    '''
    pass
    
def dist(x,y):
  '''
  Squared Euclidean distance between two tensors, summed over every element.

  The square root is skipped for speed; it does not change the ordering of distances.
  '''
  difference = x - y
  return difference.pow(2).sum()

def dist_bulk(x,y):
  '''
  Squared Euclidean distances along the last dimension, with broadcasting.

  Typical use in this file: x is an (n x d) batch of points and y a single d-dimensional
  point, giving a result of shape (n,). As with dist(), the square root is skipped.
  '''
  squared_diff = torch.pow(x - y, 2)
  return torch.sum(squared_diff, dim=-1)


In [69]:
# Testing out NNDescent class
# torch.manual_seed(1)
# Small hand-made 2-D dataset for eyeballing results. It is not used below; the demo
# runs on `data`. With only 8 points it needs k <= 7.
toy_data = torch.Tensor([[1.0,1.0], [2.0,1.0], [3.0,1.0], [4.0,1.0],
                         [1.0,2.0], [2.0,2.0], [3.0,2.0], [4.0,2.0]])
data = torch.randn(size=[10000,4])
# print(data)

k = 8

# Optional: initialize the NNDescent graph randomly (set to True to run it).
run_random_init = False
if run_random_init:
  print("Initializing randomly...")
  start = time.time()
  n = NNDescent(data, k=k, random=True)
  print("Took", time.time()-start,"seconds.\n")
  print("Graph:")
  print(n.graph)
  print("Distances:")
  print(torch.sqrt(n.k_distances))

  print("Updating...\n")
  start = time.time()
  n.update_graph(iter=25)
  print("Took", time.time()-start,"seconds.\n")
  print("Graph:")
  print(n.graph)
  print("Distances:")
  print(torch.sqrt(n.k_distances))


# Initialize NNDescent graph with forest
print("Initializing forest...")
start = time.time()
n = NNDescent(data, k=k, random=False)
print("Took", time.time()-start,"seconds.\n")
print("Graph:")
print(n.graph)
print("Distances:")
print(torch.sqrt(n.k_distances))

print("Updating...\n")
start = time.time()
n.update_graph(iter=25)
print("Took", time.time()-start,"seconds.\n")
print("Graph:")
print(n.graph)
print("Distances:")
print(torch.sqrt(n.k_distances))

# Brute force search
print("Brute Force")
start = time.time()
# Exact pairwise squared distances (N x N). torch.cdist avoids the (N x N x d) broadcast
# intermediate (~1.6 GB for N=10000, d=4). The non-matmul mode keeps the values accurate
# enough to serve as ground truth.
m = torch.cdist(data, data, compute_mode='donot_use_mm_for_euclid_dist').pow_(2)
m.fill_diagonal_(float('inf'))  # a point is never its own neighbour
brute_force = torch.topk(m, k=k, dim=1, largest=False)[1]
print("Took", time.time()-start,"seconds.\n")
print(brute_force)

# Get k nearest neighbors using graph search
print()
print("Graph search")
#x = torch.Tensor([4, 0]).unsqueeze(0) # size 1 x d
x = torch.randn(size=[1,data.shape[1]])
print("x is",x)
start = time.time()
k_nearest = n.k_nearest_graph_search(x)
print("Took", time.time()-start,"seconds.")
print("Nearest indices with graph search:",k_nearest)
print("coordinates of nearest dots:",n.data[k_nearest])

# Get k nearest using brute force knn
print("\nActual nearest using KNN")
start = time.time()
m = dist_bulk(data, x)  # squared distance from every data point to x, shape (N,)
print(m)
brute_force = torch.sort(m,dim=0)[1][:k]
print("Took", time.time()-start,"seconds.")
print("The KNN nearest are:",brute_force)
print("coordinates of nearest dots:",n.data[brute_force])

Initializing forest...
Took 3.5460469722747803 seconds.

Graph:
tensor([[1564, 4474, 1956,  ..., 1435, 5476, 5240],
        [ 513, 4605, 7101,  ..., 8732, 2244,  456],
        [8865, 7906, 2213,  ..., 5828, 9066, 3063],
        ...,
        [8504, 8495, 3966,  ..., 9641, 7798, 2392],
        [1399,  994, 9306,  ..., 4866, 5779, 3124],
        [2643, 4588, 6014,  ..., 8957,  965, 9367]])
Distances:
tensor([[1.2512, 1.4119, 1.6376,  ..., 1.9776, 2.0939, 2.1132],
        [0.5113, 0.7892, 0.8453,  ..., 1.1319, 1.1671, 1.2418],
        [0.8918, 0.9721, 0.9884,  ..., 1.1155, 1.1532, 1.1989],
        ...,
        [0.8284, 0.8757, 1.1301,  ..., 1.3635, 1.3636, 1.4027],
        [0.6265, 0.7075, 0.7634,  ..., 1.1946, 1.2185, 1.2236],
        [1.4927, 1.5737, 1.7241,  ..., 1.9100, 1.9203, 1.9472]])
Updating...

Iteration number 0 with average distance of 1.911115288734436 Took 0.0004134178161621094 seconds.
Iteration number 1 with average distance of 0.8456408381462097 Took 33.46454191207886 seco

In [None]:
# Sandbox for testing: replays one NNDescent.update_graph step for a single node.
# Requires the previous cell to have run (it defines `n` and `data`).
i = 0
neighbors = n.graph[i]
print("Neighbors:",neighbors)
print("Neighbors of neighbors:",n.graph[neighbors].flatten())
potential_neighbors = {a.item() for a in n.graph[neighbors].flatten() if a not in neighbors and a!=i}
print("New potential neighbors:", potential_neighbors)
# Loop variable is `p`, not `n`, so it cannot be confused with the NNDescent instance.
potential_distances = torch.Tensor([dist(data[i],data[p]) for p in potential_neighbors])
print("Potential distances:", (potential_distances))
# Long dtype keeps the indices exact (float32 cannot represent all integers above 2**24).
cat_idx = torch.cat([neighbors, torch.tensor(list(potential_neighbors), dtype=torch.long)])
cat_dist = torch.cat([n.k_distances[i], potential_distances])
print("cat_idx:", cat_idx)
print("cat_dist:", cat_dist)
print("sort cat:", torch.sort(cat_dist))
val, idx = torch.sort(cat_dist)
# Same test update_graph uses: did any new candidate make it into the k best?
print("idx max", torch.max(idx[:n.k]) >= n.k)
print()
print("New neighbors:",cat_idx[idx[:n.k]])
print("New distances:",val[:n.k])

NameError: name 'n' is not defined