Rewritten NNDescent algorithm using matrices instead of dictionaries.

TODO: Remove for-loops where possible.

In [5]:
import torch
import time
#import matplotlib.pyplot as plt
#from collections import defaultdict

class NNDescent:
  def __init__(self, data, k=3):
    # The data is a (N x d) matrix with N instances of d-dimensional points
    self.data = data
    N = data.shape[0]
    self.k = k
    
    # A 2D tensor representing a directed graph.
    # The value a = graph[i,j] represents an edge from point x_i to x_a.
    self.graph = torch.zeros(size=[N, k], dtype=torch.long)
    self.initialize_graph_randomly()

    # A set of tuples (i,j) of indices for which the distance has already been calculated.
    self.explored_edges = set()

    # A 2D tensor representing the distance between point x_i and x_graph[i,j]
    self.k_distances = torch.zeros([N, k])
    self.calculate_all_distances()
    

  def initialize_graph_randomly(self):
    '''
    Initializes self.graph with random values such that each point has k distinct neighbors
    '''
    N, k = self.graph.shape
    for i, row in enumerate(self.graph):
      random_row = torch.randperm(N-1)[:k] # k random values without replacement
      random_row[random_row >= i] += 1 # excluding i to avoid loops in graph
      self.graph[i] = random_row

  def calculate_all_distances(self):
    '''
    Updates the distances (self.k_distances) of the edges found in self.graph.
    '''
    # Note: Start with for loop for simplicity. TODO: Try to remove loop.
    for i, row in enumerate(self.graph):
      # Indices of current k neighbors in self.graph
      neighbor_indices = [(i,int(r)) for r in row]

      # The distances of those neighbors are saved in k_distances
      self.k_distances[i] = torch.Tensor([dist(self.data[a],self.data[b]) for a,b in neighbor_indices])

      # Add pairs to explored_edges set
      self.explored_edges.update(neighbor_indices) 
    

  def update_graph(self, iter=5):
    '''
      Updates the graph using algorithm: https://pynndescent.readthedocs.io/en/latest/how_pynndescent_works.html
    '''
    # [STEP 1: Start with random graph.] Iterate
    for it in range(iter):
      print("Iteration number",it,"with average distance of",torch.mean(self.k_distances).item())
      has_changed = False

      # [STEP 2: For each node:] (TODO: Investigate whether this can be vectorized.)
      for i, neighbors in enumerate(self.graph):
        # Distances of current neighbors
        dist_current_neighbors = self.k_distances[i]

        # [STEP 3: Measure distance from the node to the neighbors of its neighbors]
        # Find neighbors of neighbors
        potential_neighbors = {a.item() for a in self.graph[neighbors].flatten() \
                               if a not in neighbors and a!=i and (i,int(a)) not in self.explored_edges}
        potential_distances = torch.Tensor([dist(self.data[i],self.data[n]) for n in potential_neighbors])
        self.explored_edges.update([(i,int(r)) for r in potential_neighbors])

        # Concatenate potential neighbors to list of neighbors (indices and distances)
        cat_idx = torch.cat([neighbors, torch.Tensor(list(potential_neighbors))])
        cat_dist = torch.cat([self.k_distances[i], potential_distances])

        # [STEP 4: If any are closer, then update the graph accordingly, and only keep the k closest]
        # Sort using torch.sort(), which also returns sorted indices
        dist_sorted, idx = torch.sort(cat_dist)
        if torch.max(idx[:self.k]) >= self.k:
          has_changed = True
          self.graph[i] = cat_idx[idx[:self.k]]
          self.k_distances[i] = dist_sorted[:self.k]
        
      # [STEP 5: If any changes were made, repeat iteration, otherwise stop]
      if not has_changed:
        print("Nothing changed in iteration",it)
        break
    print("Done.")

    def predict(self,x):
      '''
      Predict output using tree. Hasn't been implemented yet.
      '''
      pass


def dist(x,y):
  # Square of euclidian distance. Skip the root for faster computation.
  return torch.sum((x-y)**2)



In [6]:
# Testing out NNDescent class
torch.manual_seed(1)
data = torch.Tensor([[1.0,1.0], [2.0,1.0], [3.0,1.0], [4.0,1.0],
                     [1.0,2.0], [2.0,2.0], [3.0,2.0], [4.0,2.0]])  
#data = torch.randn(size=[1000,4])
print(data)  

# Initialize NNDescent graph
n = NNDescent(data, k=3)
print("Graph:")
print(n.graph)
print("Distances:")
print(torch.sqrt(n.k_distances))

print("Updating...\n")
start = time.time()
n.update_graph(iter=25)
print("Took", time.time()-start,"seconds.\n")
print("Graph:")
print(n.graph)
print("Distances:")
print(torch.sqrt(n.k_distances))
#print(n.k_distances)


tensor([[1., 1.],
        [2., 1.],
        [3., 1.],
        [4., 1.],
        [1., 2.],
        [2., 2.],
        [3., 2.],
        [4., 2.]])
Graph:
tensor([[1, 7, 2],
        [0, 7, 2],
        [6, 7, 0],
        [5, 6, 2],
        [2, 1, 7],
        [0, 3, 4],
        [4, 1, 0],
        [1, 6, 5]])
Distances:
tensor([[1.0000, 3.1623, 2.0000],
        [1.0000, 2.2361, 1.0000],
        [1.0000, 1.4142, 2.0000],
        [2.2361, 1.4142, 1.0000],
        [2.2361, 1.4142, 3.0000],
        [1.4142, 2.2361, 1.0000],
        [2.0000, 1.4142, 2.2361],
        [2.2361, 1.0000, 2.0000]])
Updating...

Iteration number 0 with average distance of 3.4166667461395264
Iteration number 1 with average distance of 1.5
Iteration number 2 with average distance of 1.3333333730697632
Nothing changed in iteration 2
Done.
Took 0.012814998626708984 seconds.

Graph:
tensor([[1, 4, 5],
        [0, 2, 5],
        [6, 1, 7],
        [2, 7, 6],
        [0, 5, 1],
        [4, 1, 6],
        [2, 5, 7],
        [6,

In [5]:
# Sandbox for testing
i = 0
neighbors = n.graph[i]
print("Neighbors:",neighbors)
print("Neighbors of neighbors:",n.graph[neighbors].flatten())
potential_neighbors = {a.item() for a in n.graph[neighbors].flatten() if a not in neighbors and a!=i}
print("New potential neighbors:", potential_neighbors)
potential_distances = torch.Tensor([dist(data[i],data[n]) for n in potential_neighbors])
print("Potential distances:", (potential_distances))
cat_idx = torch.cat([neighbors, torch.Tensor(list(potential_neighbors))])
cat_dist = torch.cat([n.k_distances[i], potential_distances])
print("cat_idx:", cat_idx)
print("cat_dist:", cat_dist)
print("sort cat:", torch.sort(cat_dist))
val, idx = torch.sort(cat_dist)
print("idx max", torch.max(idx)>3)
print()
print("New neighbors:",cat_idx[idx[:3]])
print("New distances:",val[:3])

Neighbors: tensor([1, 4, 5])
Neighbors of neighbors: tensor([0, 2, 5, 0, 5, 1, 4, 1, 6])
New potential neighbors: {2, 6}
Potential distances: tensor([4., 5.])
cat_idx: tensor([1., 4., 5., 2., 6.])
cat_dist: tensor([1., 1., 2., 4., 5.])
sort cat: torch.return_types.sort(
values=tensor([1., 1., 2., 4., 5.]),
indices=tensor([0, 1, 2, 3, 4]))
idx max tensor(True)

New neighbors: tensor([1., 4., 5.])
New distances: tensor([1., 1., 2.])
