# kNN (K-Nearest Neighbors)

In [1]:
import torch
import pandas as pd

In [2]:
chunks = torch.rand(2,10,20) #Sample chunks
e_db = pd.DataFrame(torch.rand(20000,1,20,20).tolist()) #Sample embedding database

In [3]:
#Chunks might be getting built wrong where its cutting them in half the wrong way

def get_kNN(chunks, e_db, k = 2):
    """
    Input: chunks - tensor containing initial data
            e_db - dataframe containing embeddings
    Description: find k-nearest-neighbours of input tensor
    Output: tensor containing the k-nearest-neighbours of input tensor
    """
    neighbours = torch.tensor([])
    for i, chunk in enumerate(chunks):
        e_db['L2'] = e_db.apply(lambda x:torch.linalg.norm(chunk - torch.tensor(x[0][chunk.size(0) * i:chunk.size(0) * (i + 1)])).item(), axis=1)
        kNN = torch.tensor([e_db.nsmallest(k, ['L2'])[0].tolist()])
        neighbours = torch.cat([neighbours, kNN])
    return neighbours

neighbours = get_kNN(chunks, e_db)
print(neighbours, '\n', neighbours.size())

tensor([[[[0.8339, 0.4058, 0.4865,  ..., 0.6056, 0.5309, 0.8444],
          [0.7127, 0.6784, 0.0095,  ..., 0.6647, 0.1412, 0.1697],
          [0.5732, 0.0255, 0.7633,  ..., 0.3600, 0.2175, 0.6936],
          ...,
          [0.5702, 0.5199, 0.6894,  ..., 0.1080, 0.9961, 0.1616],
          [0.1445, 0.5681, 0.4068,  ..., 0.4583, 0.7313, 0.2490],
          [0.3081, 0.8959, 0.4925,  ..., 0.1634, 0.4041, 0.9460]],

         [[0.0771, 0.7256, 0.3806,  ..., 0.9940, 0.9824, 0.1998],
          [0.5780, 0.7712, 0.1326,  ..., 0.3795, 0.7418, 0.3705],
          [0.9809, 0.8807, 0.6996,  ..., 0.0742, 0.1150, 0.0569],
          ...,
          [0.2882, 0.7255, 0.1155,  ..., 0.1654, 0.0761, 0.3060],
          [0.4759, 0.4365, 0.9990,  ..., 0.3860, 0.9442, 0.2466],
          [0.2106, 0.2431, 0.9540,  ..., 0.1373, 0.5787, 0.7914]]],


        [[[0.2835, 0.4034, 0.2680,  ..., 0.7233, 0.4990, 0.2989],
          [0.1439, 0.8564, 0.1550,  ..., 0.7646, 0.1905, 0.0646],
          [0.9763, 0.8006, 0.7100,  ...,

In [4]:
def new_kNN(chunks, e_db, k = 2):
    """
    Input: chunks - tensor containing initial data
            e_db - dataframe containing embeddings
    Description: find k-nearest-neighbours of input tensor
    Output: tensor containing the k-nearest-neighbours of input tensor
    """
    e_db = torch.tensor(e_db[0])
    neighbours = torch.tensor([])
    for i, chunk in enumerate(chunks):
        neighbours = torch.cat(
            [
                neighbours, 
                e_db[
                    torch.linalg.matrix_norm(
                        chunk - e_db[:, chunk.size(0) * i : chunk.size(0) * (i + 1)] #Compare slice with chunk
                    ).topk(k, largest = False).indices #Index of k nearest neighbours
                ][None, :, :]
            ]
        )
    return neighbours
    
neighbours = new_kNN(chunks, e_db)
print(neighbours, '\n', neighbours.size())

tensor([[[[0.8339, 0.4058, 0.4865,  ..., 0.6056, 0.5309, 0.8444],
          [0.7127, 0.6784, 0.0095,  ..., 0.6647, 0.1412, 0.1697],
          [0.5732, 0.0255, 0.7633,  ..., 0.3600, 0.2175, 0.6936],
          ...,
          [0.5702, 0.5199, 0.6894,  ..., 0.1080, 0.9961, 0.1616],
          [0.1445, 0.5681, 0.4068,  ..., 0.4583, 0.7313, 0.2490],
          [0.3081, 0.8959, 0.4925,  ..., 0.1634, 0.4041, 0.9460]],

         [[0.0771, 0.7256, 0.3806,  ..., 0.9940, 0.9824, 0.1998],
          [0.5780, 0.7712, 0.1326,  ..., 0.3795, 0.7418, 0.3705],
          [0.9809, 0.8807, 0.6996,  ..., 0.0742, 0.1150, 0.0569],
          ...,
          [0.2882, 0.7255, 0.1155,  ..., 0.1654, 0.0761, 0.3060],
          [0.4759, 0.4365, 0.9990,  ..., 0.3860, 0.9442, 0.2466],
          [0.2106, 0.2431, 0.9540,  ..., 0.1373, 0.5787, 0.7914]]],


        [[[0.2835, 0.4034, 0.2680,  ..., 0.7233, 0.4990, 0.2989],
          [0.1439, 0.8564, 0.1550,  ..., 0.7646, 0.1905, 0.0646],
          [0.9763, 0.8006, 0.7100,  ...,

## Small data sample

In [5]:
chunks = torch.rand(2,10,20) #Sample chunks
e_db = pd.DataFrame(torch.rand(20,1,20,20).tolist()) #Sample embedding database

In [6]:
%timeit -n 10 get_kNN(chunks, e_db)

3.45 ms ± 487 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [7]:
%timeit -n 10 new_kNN(chunks, e_db)

630 µs ± 283 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


## Medium data sample

In [8]:
chunks = torch.rand(2,10,20) #Sample chunks
e_db = pd.DataFrame(torch.rand(20000,1,20,20).tolist()) #Sample embedding database

In [9]:
%timeit -n 10 get_kNN(chunks, e_db)

926 ms ± 67 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [10]:
%timeit -n 10 new_kNN(chunks, e_db)

327 ms ± 12.7 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


## Large data sample

In [11]:
chunks = torch.rand(2,10,20) #Sample chunks
e_db = pd.DataFrame(torch.rand(500000,1,20,20).tolist()) #Sample embedding database

In [12]:
%timeit -n 10 get_kNN(chunks, e_db)

1min 25s ± 9.3 s per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [13]:
%timeit -n 10 new_kNN(chunks, e_db)

KeyboardInterrupt: 

## More neighbours

In [None]:
chunks = torch.rand(2,10,20) #Sample chunks
e_db = pd.DataFrame(torch.rand(20000,1,20,20).tolist()) #Sample embedding database

In [None]:
%timeit -n 10 get_kNN(chunks, e_db, k=100)

In [None]:
%timeit -n 10 new_kNN(chunks, e_db, k=100)

## Results

In [None]:
df = pd.DataFrame(
    {
        'v1':[3.45, 926, ],
        'v2':[0.63, 327, ]
    },
    index=[20, 20000, 500000]
)

df.plot.line()