In [1]:
import numpy as np

from sklearn.neighbors import NearestNeighbors

# Initialise a random array of 1000 3D points

In [2]:
# Draw 1000 points uniformly from [0, 1) in each of 3 coordinates, one per row.
n_points, n_dims = 1000, 3
X = np.random.random(size=(n_points, n_dims))
X.shape

(1000, 3)

# Use broadcasting to find pairwise differences

In [3]:
# Insert a length-1 middle axis so X is viewed as (n, 1, d); subtracting the
# original (n, d) array then broadcasts over every ordered pair of points:
# diff[a, b] = X[a] - X[b].
# Indexing with np.newaxis gives the same result as reshape(1000, 1, 3) without
# hard-coding the number of points or the dimensionality.
diff = X[:, np.newaxis, :] - X
diff.shape

(1000, 1000, 3)

# Sum the squared differences over the coordinate axis to get pairwise squared distances

In [4]:
# Square every coordinate difference, then add them up along the last axis,
# leaving one squared Euclidean distance per ordered pair of points.
D = np.square(diff).sum(axis=-1)
D.shape

(1000, 1000)

# Set diagonal to infinity to skip self-neighbours

In [5]:
# Each point is at distance 0 from itself, which would win every row-wise
# argmin. Overwrite the main diagonal in place with +inf so a point can never
# be chosen as its own nearest neighbour.
# np.fill_diagonal works for any square D, so the number of points is not
# hard-coded and no temporary index array is needed.
np.fill_diagonal(D, np.inf)

# Print the indices of the nearest neighbor

In [6]:
# For each row of D, take the column that holds the smallest squared distance;
# only the first ten results are shown.
i = D.argmin(axis=1)
print(i[:10])

[338 792 376 995 430 506 334  64 325 445]


# Double check with scikit-learn

In [7]:
# Fit on X and query with the same points, asking for two neighbours each.
# Column 0 is expected to be the query point itself (distance 0), so column 1
# is printed for comparison with the brute-force result.
knn = NearestNeighbors().fit(X)
d, i = knn.kneighbors(X, n_neighbors=2)
print(i[:10, 1])

[338 792 376 995 430 506 334  64 325 445]


# With a smaller matrix to show what's happening

In [8]:
# Redraw a tiny 5-point sample so every intermediate array fits on screen.
# Note: this rebinds X, replacing the 1000-point array used in earlier cells.
X = np.random.random(size=(5, 3))
print(X)

[[0.07260753 0.68026347 0.62210483]
 [0.61062272 0.15505244 0.22953415]
 [0.21876177 0.46364837 0.79854858]
 [0.78720567 0.53666684 0.9305735 ]
 [0.0879056  0.32535417 0.24104592]]


In [9]:
# View X as (n, 1, d) and subtract the original (n, d) array; broadcasting
# yields diff[a, b] = X[a] - X[b] for every ordered pair of points.
# np.newaxis gives the same result as reshape(5, 1, 3) without hard-coding the
# array size.
diff = X[:, np.newaxis, :] - X
print(diff)

[[[ 0.          0.          0.        ]
  [-0.53801518  0.52521103  0.39257068]
  [-0.14615424  0.2166151  -0.17644375]
  [-0.71459814  0.14359663 -0.30846867]
  [-0.01529807  0.3549093   0.38105891]]

 [[ 0.53801518 -0.52521103 -0.39257068]
  [ 0.          0.          0.        ]
  [ 0.39186094 -0.30859593 -0.56901443]
  [-0.17658296 -0.3816144  -0.70103935]
  [ 0.52271711 -0.17030173 -0.01151177]]

 [[ 0.14615424 -0.2166151   0.17644375]
  [-0.39186094  0.30859593  0.56901443]
  [ 0.          0.          0.        ]
  [-0.5684439  -0.07301847 -0.13202493]
  [ 0.13085617  0.1382942   0.55750266]]

 [[ 0.71459814 -0.14359663  0.30846867]
  [ 0.17658296  0.3816144   0.70103935]
  [ 0.5684439   0.07301847  0.13202493]
  [ 0.          0.          0.        ]
  [ 0.69930007  0.21131267  0.68952759]]

 [[ 0.01529807 -0.3549093  -0.38105891]
  [-0.52271711  0.17030173  0.01151177]
  [-0.13085617 -0.1382942  -0.55750266]
  [-0.69930007 -0.21131267 -0.68952759]
  [ 0.          0.          0.        ]]]

In [10]:
# Collapse the coordinate axis: the sum of squared differences is the squared
# Euclidean distance for each ordered pair of points.
D = np.square(diff).sum(axis=-1)
print(D)

[[0.         0.7194187  0.09941556 0.62642341 0.27140054]
 [0.7194187  0.         0.57256387 0.66826727 0.30236838]
 [0.09941556 0.57256387 0.         0.34589075 0.34705784]
 [0.62642341 0.66826727 0.34589075 0.         1.00912193]
 [0.27140054 0.30236838 0.34705784 1.00912193 0.        ]]


In [11]:
# Replace the zero self-distances on the main diagonal with +inf (in place) so
# that a point is never selected as its own nearest neighbour.
# np.fill_diagonal adapts to the size of D, so the point count is not repeated.
np.fill_diagonal(D, np.inf)
print(D)

[[       inf 0.7194187  0.09941556 0.62642341 0.27140054]
 [0.7194187         inf 0.57256387 0.66826727 0.30236838]
 [0.09941556 0.57256387        inf 0.34589075 0.34705784]
 [0.62642341 0.66826727 0.34589075        inf 1.00912193]
 [0.27140054 0.30236838 0.34705784 1.00912193        inf]]


In [12]:
# Row-wise argmin: entry k is the column index of the smallest value in row k
# of D, i.e. the nearest neighbour of point k.
i = D.argmin(axis=1)
print(i)

[2 4 0 2 0]
