In [None]:
def ivf_accuracy(indices_test, indices_truth):
  '''
  Score approximate nearest-neighbour indices against ground-truth indices.

  indices_test:  (N, k) tensor; row i holds the neighbour indices found for point i
  indices_truth: tensor compared element-wise with indices_test (expected to have
                 the same (N, k) shape) holding the true neighbour indices
  Returns, as a Python float, the proportion of entries of indices_test that occur
  somewhere in the matching row of indices_truth (column order is ignored; assumes
  the indices within each row are distinct).
  '''
  num_points, num_neighbours = indices_test.shape

  # Compare against every cyclic shift of the truth rows, so a correct neighbour
  # is counted no matter which column it appears in.
  hits = sum(
    torch.sum(indices_test == torch.roll(indices_truth, shift, -1)).float() / num_points
    for shift in range(num_neighbours)
  )
  return float(hits / num_neighbours)  # percentage accuracy in [0, 1]

In [None]:
def knn_accuracy_x_y(indices_test, x, y):
  '''
  Score k-nearest-neighbour indices of query points y into reference points x.

  indices_test: (N, k) tensor; row i holds the indices into x found for y[i]
  x: reference points (assumed (M, D) tensor)
  y: query points (assumed (N, D) tensor)
  The ground truth is computed by brute force over squared Euclidean distances
  and truncated to the same k as indices_test.
  Returns, as a Python float, the proportion of entries of indices_test that are
  among the true k nearest neighbours of the corresponding query (column order
  is ignored; assumes the indices within each row are distinct).
  '''
  N, k = indices_test.shape

  # Squared distances from every query in y to every point in x: shape (len(y), len(x))
  sq_dists = ((y.unsqueeze(1) - x.unsqueeze(0))**2).sum(-1)
  # Keep exactly k true neighbours per query (previously hard-coded to 5, which
  # broke or mis-scored any k != 5)
  indices_truth = torch.argsort(sq_dists, dim=1)[:, :k]

  # Calculate number of correct nearest neighbours
  accuracy = 0
  for _ in range(k):
    accuracy += torch.sum(indices_test == indices_truth).float()/N
    # Cyclically shift the truth rows so matches in any column are counted
    indices_truth = torch.roll(indices_truth, 1, -1)
  accuracy = float(accuracy/k) # percentage accuracy

  return accuracy

In [None]:
def init_accuracy(data, graph, k_distances):
  '''
  Check a k-nearest-neighbour graph of `data` against brute-force ground truth.

  data: points, compared pairwise via broadcasting (assumed (N, D) tensor)
  graph: (N, k) tensor; row i holds the neighbour indices assigned to point i
  k_distances: distances of the graph's neighbours; only their mean is used.
               NOTE(review): the true distances here are *squared* Euclidean,
               so k_distances presumably must be squared too — confirm with caller.
  Returns (accuracy, distance_error) as Python floats:
    accuracy: proportion of graph entries that are among the true k nearest
              neighbours of their point (column order ignored; a point is never
              its own neighbour)
    distance_error: (mean(k_distances) - mean(true k distances)) / mean(true k distances)
  '''
  N, k = graph.shape
  graph = torch.sort(graph,dim=1)[0] # sort each row of graph (order does not affect the score)

  # Pairwise squared distances; infinity is added to the diagonal so a point is
  # never chosen as its own neighbour. The diagonal is built on data's device so
  # this also works for non-CPU tensors.
  d = ((data.unsqueeze(1)-data.unsqueeze(0))**2).sum(-1)
  d = d + torch.full((len(data),), float('inf'), device=data.device).diag()
  true_distances, true_indices = torch.sort(d,dim=1)

  # get k nearest neighbours
  true_indices = true_indices[:,:k]
  true_distances = true_distances[:,:k]

  # Calculate number of correct nearest neighbours
  accuracy = 0
  for _ in range(k):
    accuracy += torch.sum(graph == true_indices).float()
    # Cyclically shift the truth rows so matches in any column are counted.
    # (The rolled tensor must be assigned; torch.roll does not work in place.)
    true_indices = torch.roll(true_indices, 1, -1)
  accuracy = float(accuracy/(N*k)) # percentage accuracy

  # Calculate accuracy of distances
  true_average = torch.mean(true_distances)
  graph_average = torch.mean(k_distances)
  distance_error = float((graph_average-true_average)/true_average)

  return accuracy, distance_error