NNDescent algorithm, rewritten to use matrices (torch tensors) instead of dictionaries.

TODO: Remove for-loops where possible.

In [75]:
import torch
import time
#import matplotlib.pyplot as plt
#from collections import defaultdict

class NNDescent:
  '''
  Approximate k-nearest-neighbour graph refined with the NN-Descent heuristic.

  The graph lives in two (N x k) tensors: self.graph holds neighbour indices and
  self.k_distances holds the matching squared Euclidean distances.
  '''

  def __init__(self, data, k=3, trees=10, random=False):
    '''
    Builds the initial graph and computes the distances of its edges.

    data:   (N x d) tensor with N instances of d-dimensional points.
    k:      number of neighbours kept per point; must satisfy 1 <= k < N.
    trees:  number of random 'trees' pooled by the forest initialization.
    random: if True the initial neighbours are drawn uniformly at random,
            otherwise the forest initialization is used.

    Raises ValueError if k is outside [1, N-1].
    '''
    # The data is a (N x d) matrix with N instances of d-dimensional points
    self.data = data
    N = data.shape[0]
    # Every point needs k distinct neighbours other than itself.
    if not 0 < k < N:
      raise ValueError("k must satisfy 1 <= k < N (got k={}, N={})".format(k, N))
    self.k = k
    self.numtrees = trees

    # A 2D tensor representing a directed graph.
    # The value a = graph[i,j] represents an edge from point x_i to x_a.
    self.graph = torch.zeros(size=[N, k], dtype=torch.long)

    # Initialize graph randomly or with forest
    if random:
      self.initialize_graph_randomly()
    else:
      self.initialize_graph_forest(data)

    # A set of tuples (i,j) of indices for which the distance has already been calculated.
    self.explored_edges = set()

    # A 2D tensor representing the distance between point x_i and x_graph[i,j]
    self.k_distances = torch.zeros([N, k])
    self.calculate_all_distances()

  def initialize_graph_randomly(self):
    '''
    Initializes self.graph with random values such that each point has k distinct neighbors
    '''
    N, k = self.graph.shape
    for i in range(N):
      random_row = torch.randperm(N-1)[:k] # k random values without replacement
      random_row[random_row >= i] += 1 # excluding i to avoid loops in graph
      self.graph[i] = random_row

  def initialize_graph_forest(self, data):
    '''
    Initializes self.graph with a forest of random trees, such that each point has k distinct neighbors

    For every point the candidates of all trees are pooled and the k closest
    ones (squared Euclidean distance to data[i]) are kept.

    Raises ValueError if self.numtrees < 1.
    '''
    N, k = self.graph.shape
    if self.numtrees < 1:
      raise ValueError("trees must be at least 1 (got {})".format(self.numtrees))

    # One 'tree' is k shuffled copies of the N indices laid out as an (N x k) block,
    # so every index appears k times per tree. Everything stays in torch.long:
    # a round trip through float32 would corrupt indices above 2**24.
    # This is a small loop - numtrees and k << datapoints
    forest = torch.cat(
      [torch.cat([torch.randperm(N) for _ in range(k)]).reshape(N, k)
       for _ in range(self.numtrees)],
      dim=1) # N x (k*numtrees)

    # find KNN for each row in giant graph
    # TODO - implement the below without a for loop
    for i in range(N):
      candidates = torch.unique(forest[i]) # remove duplicates
      candidates = candidates[candidates != i] # remove self
      if candidates.numel() < k:
        # Duplicates/self left fewer than k candidates, which would make topk
        # fail: fall back to considering every other point for this row.
        candidates = torch.cat([torch.arange(i), torch.arange(i + 1, N)])

      distances = dist_bulk(data[candidates], data[i]) # squared Euclidean distances
      nearest = distances.topk(k=k, largest=False).indices # positions of the k closest
      self.graph[i] = candidates[nearest] # assign KNN to graph

  def calculate_all_distances(self):
    '''
    Updates the distances (self.k_distances) of the edges found in self.graph.

    Every (i, neighbor) pair of the graph is also added to self.explored_edges.
    '''
    N, k = self.graph.shape
    # (N x k x d) neighbour coordinates against (N x 1 x d) source coordinates.
    # copy_ writes into the existing buffer, so its dtype is kept.
    self.k_distances.copy_(dist_bulk(self.data[self.graph], self.data[:, None, :]))

    # Add pairs to explored_edges set (as tuples of plain Python ints)
    sources = torch.arange(N).repeat_interleave(k)
    self.explored_edges.update(zip(sources.tolist(), self.graph.flatten().tolist()))

  def update_graph(self, iter=5):
    '''
      Updates the graph using algorithm: https://pynndescent.readthedocs.io/en/latest/how_pynndescent_works.html

      iter: maximum number of passes over all points. Stops earlier when a full
            pass leaves the graph unchanged. Progress is printed on every pass.
    '''
    N = self.graph.shape[0]

    # [STEP 1: Start with random graph.] Iterate
    start = time.time()
    for it in range(iter):
      print("Iteration number",it,"with average distance of",torch.mean(self.k_distances).item(),"Took", time.time()-start,"seconds.\n")
      has_changed = False

      # [STEP 2: For each node:] (TODO: Investigate whether this can be vectorized.)
      for i in range(N):
        neighbors = self.graph[i]
        current = set(neighbors.tolist())

        # [STEP 3: Measure distance from the node to the neighbors of its neighbors]
        # Keep only points that are new for i: not i itself, not already a
        # neighbor, and not an edge whose distance was measured before.
        candidates = sorted(
          a for a in set(self.graph[neighbors].flatten().tolist())
          if a != i and a not in current and (i, a) not in self.explored_edges)
        if not candidates:
          continue # nothing new to measure, row i cannot change

        # Indices stay torch.long; a float tensor here would corrupt large indices.
        candidate_idx = torch.tensor(candidates, dtype=torch.long)
        candidate_dist = dist_bulk(self.data[candidate_idx], self.data[i]).to(self.k_distances.dtype)
        self.explored_edges.update((i, a) for a in candidates)

        # Concatenate potential neighbors to list of neighbors (indices and distances)
        cat_idx = torch.cat([neighbors, candidate_idx])
        cat_dist = torch.cat([self.k_distances[i], candidate_dist])

        # [STEP 4: If any are closer, then update the graph accordingly, and only keep the k closest]
        # Sort using torch.sort(), which also returns sorted indices
        dist_sorted, order = torch.sort(cat_dist)
        best = order[:self.k]
        # Positions >= k belong to candidates, so one of them made the top k.
        if torch.max(best) >= self.k:
          has_changed = True
          self.graph[i] = cat_idx[best]
          self.k_distances[i] = dist_sorted[:self.k]

      # [STEP 5: If any changes were made, repeat iteration, otherwise stop]
      if not has_changed:
        print("Nothing changed in iteration",it)
        break
    print("Done.")

  def predict(self,x):
    '''
    Predict output using tree. Hasn't been implemented yet.

    Raises NotImplementedError.
    '''
    raise NotImplementedError("NNDescent.predict has not been implemented yet")

def dist(x,y):
  '''
  Squared Euclidean distance between x and y, summed over all elements.
  The square root is skipped for faster computation.
  '''
  difference = x - y
  return torch.sum(difference ** 2)

def dist_bulk(x,y):
  '''
  Squared Euclidean distance along the last axis of x - y, for datasets.
  The square root is skipped for faster computation.
  '''
  squared_difference = (x - y) ** 2
  return squared_difference.sum(-1)


In [76]:
# Testing out NNDescent class
# torch.manual_seed(1)

# Tiny 2x4 grid that is easy to check by eye; assign it to `data` when debugging.
toy_data = torch.Tensor([[1.0,1.0], [2.0,1.0], [3.0,1.0], [4.0,1.0],
                         [1.0,2.0], [2.0,2.0], [3.0,2.0], [4.0,2.0]])
data = torch.randn(size=[10000,4])
# print(data)

# Run the same build / refine / report sequence for both initializations:
# first a random graph (random=True), then the forest (random=False).
for use_random in (True, False):
  print("Initializing...")
  start = time.time()
  n = NNDescent(data, k=3, random=use_random)
  print("Took", time.time()-start,"seconds.\n")
  print("Graph:")
  print(n.graph)
  print("Distances:")
  # k_distances holds squared distances, hence the sqrt for display
  print(torch.sqrt(n.k_distances))

  print("Updating...\n")
  start = time.time()
  n.update_graph(iter=25)
  print("Took", time.time()-start,"seconds.\n")
  print("Graph:")
  print(n.graph)
  print("Distances:")
  print(torch.sqrt(n.k_distances))
  #print(n.k_distances)

Initializing...
Took 4.5824549198150635 seconds.

Graph:
tensor([[4171, 3724, 3449],
        [6047, 9386,  930],
        [6311, 5962, 2761],
        ...,
        [4886, 2380, 9333],
        [  12,   43, 2763],
        [8291, 3449, 6722]])
Distances:
tensor([[3.3938, 2.1097, 1.9777],
        [1.8369, 2.4021, 2.4113],
        [2.5247, 4.8583, 2.5016],
        ...,
        [4.2610, 2.6421, 3.7299],
        [2.0635, 2.2910, 3.0519],
        [1.7187, 2.3046, 2.4228]])
Updating...

Iteration number 0 with average distance of 7.9627532958984375 Took 0.0 seconds.

Iteration number 1 with average distance of 3.148834705352783 Took 7.721186876296997 seconds.

Iteration number 2 with average distance of 2.0975418090820312 Took 18.857218027114868 seconds.

Iteration number 3 with average distance of 1.6818013191223145 Took 28.279463529586792 seconds.

Iteration number 4 with average distance of 1.5095971822738647 Took 36.91353988647461 seconds.

Iteration number 5 with average distance of 1.434106

In [192]:
# Sandbox for testing: walks through one refinement step of update_graph for a single point.
# Needs `n` (a built NNDescent instance) and `data` from the cells above; run those first.
i = 0
k = n.graph.shape[1] # neighbours per point, read from the graph instead of hard-coding 3
neighbors = n.graph[i]
print("Neighbors:",neighbors)
print("Neighbors of neighbors:",n.graph[neighbors].flatten())
potential_neighbors = {a.item() for a in n.graph[neighbors].flatten() if a not in neighbors and a!=i}
print("New potential neighbors:", potential_neighbors)
# Fixed order so that indices and distances below line up
candidates = sorted(potential_neighbors)
potential_distances = torch.Tensor([dist(data[i],data[p]) for p in candidates])
print("Potential distances:", (potential_distances))
# Indices are kept as torch.long so they stay exact and usable for indexing
cat_idx = torch.cat([neighbors, torch.tensor(candidates, dtype=torch.long)])
cat_dist = torch.cat([n.k_distances[i], potential_distances])
print("cat_idx:", cat_idx)
print("cat_dist:", cat_dist)
print("sort cat:", torch.sort(cat_dist))
val, idx = torch.sort(cat_dist)
# Same test as update_graph: a position >= k among the k best means a candidate got in
print("idx max", torch.max(idx[:k])>=k)
print()
print("New neighbors:",cat_idx[idx[:k]])
print("New distances:",val[:k])

NameError: name 'n' is not defined