In [197]:
import numpy as np

def euclidean_distance(x, y):
    n = x.shape[0]
    m = y.shape[0]
    PRECISION = 8
    xNormsSquared = np.tile(np.apply_along_axis(lambda x: np.linalg.norm(x) ** 2, 1, x).reshape(1, n).T, m)
    yNormsSquared = np.tile(np.apply_along_axis(lambda x: np.linalg.norm(x) ** 2, 1, y), (n, 1))
    xyProduct = x @ y.T
    
    assert(xNormsSquared.shape == (n, m))
    assert(yNormsSquared.shape == (n, m))
    assert(xyProduct.shape == (n, m))
    result = np.round(xNormsSquared + yNormsSquared - 2 * xyProduct, PRECISION)
    print(f"result = {result}")
    return result ** 0.5


def cosine_distance(x, y):
    n = x.shape[0]
    m = y.shape[0]
    
    xNorms = np.tile(np.apply_along_axis(np.linalg.norm, 1, x).reshape(1, n).T, m)
    yNorms = np.tile(np.apply_along_axis(np.linalg.norm, 1, y), (n, 1))
    xyProduct = x @ y.T
    
    assert(xNorms.shape == (n, m))
    assert(yNorms.shape == (n, m))
    assert(xyProduct.shape == (n, m))
    
    return 1.0 - xyProduct / xNorms / yNorms


def get_best_indices(ranks, top, axis=1):
    indexes = np.argpartition(ranks, top - 1, axis=axis).take(indices=range(0, top), axis=axis)
    values = np.take_along_axis(ranks, indexes, axis=axis)
    cuttedSortedIndexes = np.argsort(values, axis=axis)
    return (np.take_along_axis(indexes, cuttedSortedIndexes, axis=axis), np.take_along_axis(values, cuttedSortedIndexes, axis=axis))


In [129]:
indexes = np.argpartition(x, top - 1, axis=1).take(indices=range(0, top), axis=1)
print(indexes)
values = np.take_along_axis(x, indexes, axis=1).sort
print(values)

[[2 3 7 6 5]
 [4 0 1 5 2]
 [8 5 4 3 6]
 [3 4 8 9 2]
 [3 4 0 9 2]
 [0 1 5 6 7]
 [3 4 1 9 2]
 [1 6 8 9 2]
 [5 9 2 6 0]
 [1 9 2 3 0]]
[[2 1 3 3 4]
 [1 2 5 4 6]
 [0 0 0 1 2]
 [0 0 1 3 3]
 [0 0 0 0 0]
 [1 2 3 3 7]
 [3 5 2 5 5]
 [0 0 2 2 2]
 [3 5 2 1 5]
 [1 0 0 0 2]]


In [187]:
class NearestNeighborsFinder:
    def __init__(self, n_neighbors, metric="euclidean"):
        self.n_neighbors = n_neighbors

        if metric == "euclidean":
            self._metric_func = euclidean_distance
        elif metric == "cosine":
            self._metric_func = cosine_distance
        else:
            raise ValueError("Metric is not supported", metric)
        self.metric = metric

    def fit(self, X, y=None):
        self._X = X
        return self

    def kneighbors(self, X, return_distance=False):
        axis = 1
        top = self.n_neighbors
        m = X.shape[0]

        distanceMatrix = self._metric_func(X, self._X)
        indexes = np.argpartition(distanceMatrix, top - 1, axis=axis).take(indices=range(0, top), axis=axis)
        values = np.take_along_axis(distanceMatrix, indexes, axis=axis)
        cuttedSortedIndexes = np.argsort(values, axis=axis)
        
        resultDistances = np.take_along_axis(values, cuttedSortedIndexes, axis=axis)
        resulIndicies = np.take_along_axis(indexes, cuttedSortedIndexes, axis=axis)
        resultIndicies = np.tile(resulIndicies[:,0].reshape(-1, 1), top)
        assert(resultDistances.shape == (m, top))

        if return_distance:
            return (resultDistances, resulIndicies)
        else:
            return resulIndicies

In [198]:
seed = np.random.RandomState(9872)
X = seed.permutation(500).reshape(10, -1)

nn = NearestNeighborsFinder(n_neighbors=1, metric='euclidean')
nn.fit(X)

distances, indices = nn.kneighbors(X, return_distance=True)
assert(np.all(np.arange(len(X))[:, np.newaxis] == indices))
assert(np.all(np.zeros(len(X))[:, np.newaxis] == distances))

result = [[     -0. 2144347. 2571254. 1871857. 2083612. 1714306. 1967429. 3131056.
  2240076. 1915921.]
 [2144347.       0. 2459455. 1944190. 2057505. 2363739. 1666290. 2368285.
  2382357. 1942834.]
 [2571254. 2459455.       0. 2615743. 1841178. 2681722. 2738177. 2357054.
  2330298. 1862853.]
 [1871857. 1944190. 2615743.      -0. 1963535. 1793177. 1558122. 2264665.
  1720663. 1965048.]
 [2083612. 2057505. 1841178. 1963535.       0. 1888538. 1953349. 2340966.
  1986812. 1553503.]
 [1714306. 2363739. 2681722. 1793177. 1888538.       0. 1503077. 2776324.
  2077204. 2435519.]
 [1967429. 1666290. 2738177. 1558122. 1953349. 1503077.       0. 2362557.
  2033321. 1343426.]
 [3131056. 2368285. 2357054. 2264665. 2340966. 2776324. 2362557.       0.
  2740334. 1975095.]
 [2240076. 2382357. 2330298. 1720663. 1986812. 2077204. 2033321. 2740334.
       -0. 1982505.]
 [1915921. 1942834. 1862853. 1965048. 1553503. 2435519. 1343426. 1975095.
  1982505.       0.]]


In [195]:
1**2+17**2+4**2+39**2+8**2

1891

In [194]:
X

array([[22,  1, 19, 46, 48],
       [21, 18, 15,  7, 40],
       [27, 20,  6,  5,  4],
       [12, 34, 29, 35, 44],
       [26, 37, 49, 13,  8],
       [14, 24, 31, 38, 10],
       [41,  3, 42, 43, 47],
       [ 0, 32, 30, 17, 36],
       [23,  2, 33, 16,  9],
       [39, 28, 45, 25, 11]])

In [180]:
(-0.0)**0.

0.0