In [1]:
import torch
import time
#import matplotlib.pyplot as plt
#from collections import defaultdict

class NNDescent:
  '''
  Approximate k-nearest-neighbour graph built with the NN-descent heuristic.

  Attributes:
    data: (N x d) tensor holding N points of dimension d.
    k: number of neighbours stored per point.
    graph: (N x k) long tensor; a = graph[i,j] is an edge from x_i to x_a.
    k_distances: (N x k) tensor; k_distances[i,j] is the squared Euclidean
      distance between x_i and x_graph[i,j].
    explored_edges: set of (i,j) index pairs whose distance has already been
      computed, so that no pair is evaluated twice.
  '''

  def __init__(self, data, k=3):
    '''
    Builds a random k-neighbour graph over data and computes its edge distances.

    Args:
      data: (N x d) tensor with N instances of d-dimensional points.
      k: neighbours per point; must satisfy 1 <= k <= N-1.

    Raises:
      ValueError: if k distinct neighbours (excluding the point itself) cannot
        be drawn from the N points.
    '''
    # The data is a (N x d) matrix with N instances of d-dimensional points
    self.data = data
    N = data.shape[0]
    if not 0 < k < N:
      raise ValueError("k must satisfy 1 <= k <= N-1, got k=%d for N=%d points" % (k, N))
    self.k = k

    # A 2D tensor representing a directed graph.
    # The value a = graph[i,j] represents an edge from point x_i to x_a.
    self.graph = torch.zeros(size=[N, k], dtype=torch.long)
    self.initialize_graph_randomly()

    # A set of tuples (i,j) of indices for which the distance has already been calculated.
    self.explored_edges = set()

    # A 2D tensor representing the distance between point x_i and x_graph[i,j]
    self.k_distances = torch.zeros([N, k])
    self.calculate_all_distances()

  def _squared_distances(self, i, indices):
    '''
    Squared Euclidean distances from point i to every point listed in indices.

    Kept inside the class so that it does not depend on any module-level name
    that a later notebook cell might rebind.

    Args:
      i: index of the reference point.
      indices: 1D long tensor of point indices.

    Returns:
      1D tensor (same dtype as self.k_distances) with one distance per index.
    '''
    diff = self.data[indices] - self.data[i]
    return diff.flatten(start_dim=1).pow(2).sum(dim=1).to(self.k_distances.dtype)

  def initialize_graph_randomly(self):
    '''
    Initializes self.graph with random values such that each point has k distinct neighbors
    '''
    N, k = self.graph.shape
    for i in range(N):
      random_row = torch.randperm(N-1)[:k] # k random values without replacement
      random_row[random_row >= i] += 1 # excluding i to avoid loops in graph
      self.graph[i] = random_row

  def calculate_all_distances(self):
    '''
    Updates the distances (self.k_distances) of the edges found in self.graph
    and records every such edge in self.explored_edges.
    '''
    # data[graph] has shape (N, k, d); subtracting data[:, None] broadcasts each
    # point against its own k neighbours, so no Python loop over rows is needed.
    diff = self.data[self.graph] - self.data.unsqueeze(1)
    self.k_distances.copy_(diff.flatten(start_dim=2).pow(2).sum(dim=2))

    # Add pairs to explored_edges set
    for i, row in enumerate(self.graph.tolist()):
      self.explored_edges.update((i, j) for j in row)

  def update_graph(self, iter=5):
    '''
    Updates the graph using algorithm: https://pynndescent.readthedocs.io/en/latest/how_pynndescent_works.html

    Args:
      iter: maximum number of passes over all nodes. The loop stops earlier
        as soon as a full pass leaves the graph unchanged.
    '''
    # [STEP 1: Start with random graph.] Iterate
    for it in range(iter):
      print("Iteration number",it,"with average distance of",torch.mean(self.k_distances).item())
      has_changed = False

      # [STEP 2: For each node:]
      for i in range(self.graph.shape[0]):
        neighbors = self.graph[i]
        current = set(neighbors.tolist())

        # [STEP 3: Measure distance from the node to the neighbors of its neighbors]
        # Candidates are neighbours of neighbours that are not i itself, not
        # already a neighbour of i, and whose distance to i was never computed.
        # Sorting gives a deterministic candidate order.
        candidates = sorted(
          j for j in set(self.graph[neighbors].flatten().tolist())
          if j != i and j not in current and (i, j) not in self.explored_edges)
        if not candidates:
          # Nothing new to evaluate, so row i cannot change.
          continue
        self.explored_edges.update((i, j) for j in candidates)

        # Keep indices as long integers; going through a float tensor would
        # silently lose exactness for large indices.
        candidate_idx = torch.tensor(candidates, dtype=torch.long)
        candidate_dist = self._squared_distances(i, candidate_idx)

        # Concatenate potential neighbors to list of neighbors (indices and distances)
        cat_idx = torch.cat([neighbors, candidate_idx])
        cat_dist = torch.cat([self.k_distances[i], candidate_dist])

        # [STEP 4: If any are closer, then update the graph accordingly, and only keep the k closest]
        # Sort using torch.sort(), which also returns sorted indices.
        # A position >= k among the k best means a candidate displaced a neighbour.
        dist_sorted, idx = torch.sort(cat_dist)
        if torch.max(idx[:self.k]) >= self.k:
          has_changed = True
          self.graph[i] = cat_idx[idx[:self.k]]
          self.k_distances[i] = dist_sorted[:self.k]

      # [STEP 5: If any changes were made, repeat iteration, otherwise stop]
      if not has_changed:
        print("Nothing changed in iteration",it)
        break
    print("Done.")

  def predict(self,x):
    '''
    Placeholder for querying the graph with a new point x.
    Hasn't been implemented yet; currently returns None.
    '''
    pass


def dist(x,y):
  # Squared Euclidean distance between tensors x and y, summed over every element.
  # The square root is deliberately skipped: it is monotonic, so it does not
  # change which neighbours are closest, and omitting it is cheaper.
  delta = x - y
  return delta.pow(2).sum()



In [2]:
# Toy problem: eight points on a 4 x 2 integer grid in the plane.
k = 3  # neighbours kept per point
d = 2  # dimensionality of each point
data = torch.tensor([
    [1.0, 1.0], [2.0, 1.0], [3.0, 1.0], [4.0, 1.0],
    [1.0, 2.0], [2.0, 2.0], [3.0, 2.0], [4.0, 2.0],
])

# Seed right before construction: the initial graph is drawn at random.
torch.manual_seed(1)
n = NNDescent(data, k=k)
n.graph

tensor([[1, 7, 2],
        [0, 7, 2],
        [6, 7, 0],
        [5, 6, 2],
        [2, 1, 7],
        [0, 3, 4],
        [4, 1, 0],
        [1, 6, 5]])

In [3]:
# Refine the random graph in place; with no argument this runs at most 5 passes
# and stops early once a pass changes nothing.
n.update_graph()

Iteration number 0 with average distance of 3.4166667461395264
Iteration number 1 with average distance of 1.5
Iteration number 2 with average distance of 1.3333333730697632
Nothing changed in iteration 2
Done.


In [4]:
# Neighbour indices per point after the update (row i lists the neighbours of point i).
n.graph

tensor([[1, 4, 5],
        [0, 2, 5],
        [6, 1, 7],
        [7, 2, 6],
        [0, 5, 1],
        [1, 6, 4],
        [2, 5, 7],
        [6, 2, 5]])

In [5]:
# Ground truth: exact k nearest neighbours from the full pairwise distance matrix.
# m[i, j] is the squared Euclidean distance between points i and j.
m = ((data.unsqueeze(1) - data.unsqueeze(0))**2).sum(-1)
# Exclude each point from its own neighbour list. Infinity works for any data
# scale, unlike adding a fixed constant to the diagonal.
m.fill_diagonal_(float("inf"))
# Use the same k as the graph being checked; largest=False picks the smallest distances.
torch.topk(m, k=n.k, dim=1, largest=False)[1]

tensor([[4, 1, 5],
        [0, 5, 2],
        [1, 6, 3],
        [7, 2, 6],
        [0, 5, 1],
        [4, 1, 6],
        [7, 2, 5],
        [3, 6, 2]])

In [7]:
# One more refinement step, written with tensor operations only.

# For every point i, take its first listed neighbour j = n.graph[i, 0]
# and fetch that neighbour's own row of neighbours, n.graph[j].
new_neighbours = n.graph[n.graph[:, 0]]
new_neighbours

tensor([[0, 2, 5],
        [1, 4, 5],
        [2, 5, 7],
        [6, 2, 5],
        [1, 4, 5],
        [0, 2, 5],
        [6, 1, 7],
        [2, 5, 7]])

In [8]:
# NOTE: this binds a second name to the very same tensor as n.graph (no copy is made),
# so any later in-place change to n.graph is visible through old_neighbours too.
old_neighbours=n.graph
old_neighbours #current neighbours in the nn graph

tensor([[1, 4, 5],
        [0, 2, 5],
        [6, 1, 7],
        [7, 2, 6],
        [0, 5, 1],
        [1, 6, 4],
        [2, 5, 7],
        [6, 2, 5]])

In [9]:
# Pool each point's current neighbours with the candidate row (2k indices per point)
# and sort every row in ascending order, which places duplicate indices side by side.
nlist = torch.cat((old_neighbours, new_neighbours), dim=1).sort(dim=1).values
nlist

tensor([[0, 1, 2, 4, 5, 5],
        [0, 1, 2, 4, 5, 5],
        [1, 2, 5, 6, 7, 7],
        [2, 2, 5, 6, 6, 7],
        [0, 1, 1, 4, 5, 5],
        [0, 1, 2, 4, 5, 6],
        [1, 2, 5, 6, 7, 7],
        [2, 2, 5, 5, 6, 7]])

In [10]:
# Every row of nlist was sorted when nlist was built, so equal indices are adjacent
# and no second sort is needed. Shift each row one position to the right, padding
# with -1 (never a valid index), and compare against the original.
# new_full keeps the dtype and device of nlist, so the concatenation never mixes dtypes.
shift = torch.cat((nlist.new_full((len(nlist), 1), -1), nlist[:, :-1]), dim=1)
# True marks the second and later occurrences of a value in a row; the first occurrence stays False.
unwanted_indices = nlist == shift
unwanted_indices

tensor([[False, False, False, False, False,  True],
        [False, False, False, False, False,  True],
        [False, False, False, False, False,  True],
        [False,  True, False, False,  True, False],
        [False, False,  True, False, False,  True],
        [False, False, False, False, False, False],
        [False, False, False, False, False,  True],
        [False,  True, False,  True, False, False]])

In [11]:
# Include this cell to also exclude each point's own index from its row,
# i.e. point 0 must not be listed as a neighbour of point 0.
# The (N, 1) column of row numbers broadcasts against nlist, so the mask always
# has the shape of nlist and does not depend on the global k.
same_index = torch.arange(len(nlist)).unsqueeze(1) == nlist  # boolean mask
unwanted_indices = torch.logical_or(same_index, unwanted_indices)  # combine both masks
unwanted_indices

tensor([[ True, False, False, False, False,  True],
        [False,  True, False, False, False,  True],
        [False,  True, False, False, False,  True],
        [False,  True, False, False,  True, False],
        [False, False,  True,  True, False,  True],
        [False, False, False, False,  True, False],
        [False, False, False,  True, False,  True],
        [False,  True, False,  True, False,  True]])

In [18]:
# Coordinates of every pooled candidate: indexing data with the index matrix nlist
# appends the coordinate axis, so entry [i, j] is the point data[nlist[i, j]].
neighbours_data = data[nlist]
neighbours_data

tensor([[[1., 1.],
         [2., 1.],
         [3., 1.],
         [1., 2.],
         [2., 2.],
         [2., 2.]],

        [[1., 1.],
         [2., 1.],
         [3., 1.],
         [1., 2.],
         [2., 2.],
         [2., 2.]],

        [[2., 1.],
         [3., 1.],
         [2., 2.],
         [3., 2.],
         [4., 2.],
         [4., 2.]],

        [[3., 1.],
         [3., 1.],
         [2., 2.],
         [3., 2.],
         [3., 2.],
         [4., 2.]],

        [[1., 1.],
         [2., 1.],
         [2., 1.],
         [1., 2.],
         [2., 2.],
         [2., 2.]],

        [[1., 1.],
         [2., 1.],
         [3., 1.],
         [1., 2.],
         [2., 2.],
         [3., 2.]],

        [[2., 1.],
         [3., 1.],
         [2., 2.],
         [3., 2.],
         [4., 2.],
         [4., 2.]],

        [[3., 1.],
         [3., 1.],
         [2., 2.],
         [2., 2.],
         [3., 2.],
         [4., 2.]]])

In [16]:
# Line every point up against its own candidates: entry [i, j] is data[i] for each
# candidate column j of nlist. expand() returns a broadcast view, so nothing is copied,
# and the width comes from nlist itself rather than from the globals k and d.
reshaped_data = data.unsqueeze(1).expand(-1, nlist.shape[1], -1)
reshaped_data

tensor([[[1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.]],

        [[2., 1.],
         [2., 1.],
         [2., 1.],
         [2., 1.],
         [2., 1.],
         [2., 1.]],

        [[3., 1.],
         [3., 1.],
         [3., 1.],
         [3., 1.],
         [3., 1.],
         [3., 1.]],

        [[4., 1.],
         [4., 1.],
         [4., 1.],
         [4., 1.],
         [4., 1.],
         [4., 1.]],

        [[1., 2.],
         [1., 2.],
         [1., 2.],
         [1., 2.],
         [1., 2.],
         [1., 2.]],

        [[2., 2.],
         [2., 2.],
         [2., 2.],
         [2., 2.],
         [2., 2.],
         [2., 2.]],

        [[3., 2.],
         [3., 2.],
         [3., 2.],
         [3., 2.],
         [3., 2.],
         [3., 2.]],

        [[4., 2.],
         [4., 2.],
         [4., 2.],
         [4., 2.],
         [4., 2.],
         [4., 2.]]])

In [19]:
# Squared Euclidean distance from every point to each of its pooled candidates;
# duplicates and self-references get an infinite distance so they sort last.
# NOTE: this rebinds the name `dist`, which the first cell defined as a function.
dist = (reshaped_data - neighbours_data).pow(2).sum(dim=-1).masked_fill(unwanted_indices, float("Inf"))

In [22]:
# Sort each row of distances in ascending order; idx holds, for every sorted
# position, the column of dist (and therefore of nlist) it came from.
sort, idx = dist.sort(dim=1)

In [15]:
# Read the candidate indices in order of increasing distance and keep the k closest:
# gathering with the first k columns of idx selects exactly those entries of nlist.
nn = nlist.gather(1, idx[:, :k])
nn

tensor([[1, 4, 5],
        [0, 2, 5],
        [1, 6, 5],
        [2, 7, 6],
        [0, 5, 1],
        [1, 4, 6],
        [2, 5, 7],
        [6, 2, 5]])

In [16]:
# The cells above only computed `nn`; none of them wrote to n.graph.
n.graph #no update since the algorithm had already terminated

tensor([[1, 4, 5],
        [0, 2, 5],
        [6, 1, 7],
        [7, 2, 6],
        [0, 5, 1],
        [1, 6, 4],
        [2, 5, 7],
        [6, 2, 5]])

In [20]:
# Masked squared distances per point; inf marks duplicate or self entries.
dist

tensor([[inf, 1., 4., 1., 2., inf],
        [1., inf, 1., 2., 1., inf],
        [1., inf, 2., 1., 2., inf],
        [1., inf, 5., 2., inf, 1.],
        [1., 2., inf, inf, 1., inf],
        [2., 1., 2., 1., inf, 1.],
        [2., 1., 1., inf, 1., inf],
        [2., inf, 4., inf, 1., inf]])

In [23]:
# Distances of the k closest candidates per point, in ascending order.
dist.gather(1, idx[:, :k])

tensor([[1., 1., 2.],
        [1., 1., 1.],
        [1., 1., 2.],
        [1., 1., 2.],
        [1., 1., 2.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 2., 4.]])

In [94]:
# Updating the KNN graph from a leaf: code for Yihang.

k = 5                      # number of points in the leaf
leaf = torch.arange(k)     # indices of the points in this leaf: 0..4
nn = torch.zeros(2*k, k-1) # neighbour table: one row per point, k-1 neighbours each

In [95]:
# Goal: fill in the first 5 rows of the NN graph (the rows indexed by `leaf`).
# At this point every entry is still zero.
nn

tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])

In [96]:
# NN update: row i lists the other k-1 members of the leaf, starting with the
# member that follows position i and wrapping around cyclically.
# The (k, 1) column of positions plus the offsets 1..k-1 gives positions modulo k.
idx_update = leaf[(torch.arange(k).unsqueeze(1) + torch.arange(1, k)) % k]
idx_update

tensor([[1, 2, 3, 4],
        [2, 3, 4, 0],
        [3, 4, 0, 1],
        [4, 0, 1, 2],
        [0, 1, 2, 3]])

In [97]:
# Write all leaf rows in a single indexed assignment (converted to float32, the dtype .float() produces).
nn[leaf] = idx_update.to(torch.float32)

In [98]:
# Rows indexed by `leaf` now hold the values of idx_update; the other rows are unchanged.
nn

tensor([[1., 2., 3., 4.],
        [2., 3., 4., 0.],
        [3., 4., 0., 1.],
        [4., 0., 1., 2.],
        [0., 1., 2., 3.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])

In [99]:
#above we see the first 5 rows updated to their neighbours

In [None]:
'''
I would advise the following:
First, obtain idx_update for each leaf and concatenate them all together.
Then, collect the indices of every leaf and concatenate them in the same order.
Finally, update the NN graph in one go with nn[leaf,:]=idx_update.
This way you do not edit the NN graph repeatedly, and you perform fewer operations per loop iteration, i.e. it is faster.
'''