In [1]:
import pandas as pd
import torch
from torch import nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.modules.distance import PairwiseDistance
import numpy as np
import matplotlib.pyplot as plt
import torch.utils.data
from tqdm import tnrange
from tqdm import tqdm_notebook

# The Mighty KNN

In [2]:
def euk2(x,y):
    """Squared Euclidean distance between two 1-D vectors ``x`` and ``y``."""
    diff = x - y
    # For 1-D operands `@` is the inner product, i.e. the sum of squared differences.
    return diff @ diff

In [3]:
def lnorm(p):
    """Build the L_p distance function for a given exponent ``p``.

    Args:
        p: exponent of the norm (1 gives Manhattan distance, 2 gives Euclidean).

    Returns:
        A callable ``dist(x, y)`` computing ``(sum_i |x_i - y_i|**p) ** (1/p)``
        for two tensors of the same shape.
    """
    def dist(x, y):
        # The absolute value is essential: without it odd exponents let negative
        # differences cancel positive ones (p=1) or make the final root NaN (p=3).
        return (x - y).abs().pow(p).sum().pow(1 / p)

    return dist

In [4]:
def dist_mat(X, Y, d=None):
    """Pairwise distance matrix between the rows of ``X`` and the rows of ``Y``.

    Args:
        X: 2-D tensor holding one point per row (n rows).
        Y: 2-D tensor holding one point per row (m rows), with as many columns
            as ``X``.
        d: callable ``d(x, y)`` returning a scalar distance between two rows.
            Passing ``euk2`` -- or leaving the default ``None``, which is
            treated as ``euk2`` -- selects the vectorised squared-Euclidean path.

    Returns:
        Tensor of shape ``(n, m)`` whose entry ``[i, j]`` is ``d(X[i], Y[j])``.
    """
    if d is None or d is euk2:
        # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 <x, y>, evaluated for all pairs at once.
        D = X @ Y.T
        return torch.sum(X**2,1,keepdim=True) + torch.sum(Y**2,1,keepdim=True).T - 2*D
    n, m = X.shape[0], Y.shape[0]
    # Generic metric: evaluate every (row of X, row of Y) pair explicitly, then
    # fold the flat list back into an (n, m) matrix so callers can sort per row.
    vals = [float(d(X[i], Y[j])) for i in range(n) for j in range(m)]
    return torch.tensor(vals, dtype=torch.get_default_dtype(), device=X.device).reshape(n, m)

In [6]:
class KNN(nn.Module):
    """k-nearest-neighbour classifier that predicts by majority vote.

    The training points are passed through ``morphism`` once, at construction
    time, and stored. Queries are passed through the same ``morphism`` on every
    ``forward`` call before distances are measured with ``distance``.
    """

    def __init__(self, X_train, y_train, k=5, morphism = lambda x: x, distance=euk2, batch_size=None):
        """Store the transformed training set and the search settings.

        Args:
            X_train: training points, one per row.
            y_train: tensor of labels aligned with the rows of ``X_train``.
            k: number of neighbours that vote on each prediction.
            morphism: callable applied to the training points here and to the
                queries in ``forward``.
            distance: callable ``distance(x, y)`` returning a scalar distance
                between two points.
            batch_size: number of randomly chosen training points considered
                per ``forward`` call; ``None`` (or 0) means all of them.
        """
        # nn.Module has to be initialised before attributes are assigned;
        # otherwise an nn.Module passed as `morphism` cannot be stored and the
        # instance cannot be invoked as `knn(Y)`.
        super().__init__()

        self.X = morphism(X_train)
        self.y = y_train
        self.M = morphism
        self.d = distance
        self.b = batch_size if batch_size else X_train.shape[0]
        self.k = k

    def forward(self, Y, fast=False):
        """Predict a label for every row of ``Y``.

        Args:
            Y: query points, one per row.
            fast: when true, build the whole distance matrix with ``dist_mat``
                and sort each row; otherwise scan the candidates one by one in
                Python while tracking the current k best.

        Returns:
            1-D tensor with one predicted label per query row, in the dtype of
            the stored labels.
        """
        # Random subset of at most `self.b` training rows used as candidates.
        ind = torch.randperm(self.X.shape[0])[:self.b]
        Y = self.M(Y)
        if fast:
            dm = dist_mat(Y, self.X[ind], self.d)
            _, I = dm.sort(1)
            # Use the labels stored on the instance, not a notebook-level `y`.
            ret, _ = self.y[ind][I[:,:self.k]].mode(1)
            return ret

        # Never keep more slots than there are candidates: an unfilled slot
        # would otherwise cast a spurious vote for training row 0.
        k = min(self.k, ind.numel())
        # Same dtype as the labels, so both code paths return the same kind of tensor.
        ret = torch.zeros(Y.shape[0], dtype=self.y.dtype)
        for i in range(Y.shape[0]):
            a = torch.zeros(k, dtype=torch.long)   # indices of the current k best
            d = torch.ones(k) * float('inf')       # their distances to Y[i]
            for j in ind:
                nd = self.d(Y[i], self.X[j])
                worst = d.argmax()
                # Replace the farthest kept neighbour when the candidate is closer.
                if d[worst] > nd:
                    a[worst] = j
                    d[worst] = nd
            ret[i], _ = self.y[a].mode()
        return ret

In [7]:
# Fix the RNG seed so the synthetic data in the following cells (and the random
# candidate subsampling done by KNN.forward) is reproducible under
# Restart Kernel -> Run All.
torch.manual_seed(0)
# Training features: 100 points with 10 features each, uniform on [0, 1).
X = torch.rand(100,10)

In [8]:
# Synthetic class labels: the ceiling of a uniform draw on [0, 5), kept as floats.
y = (torch.rand(100) * 5).ceil()

In [9]:
# Query set: 20 points with the same 10 features as the training data.
Y = torch.rand((20, 10))

In [10]:
# Classifier with the constructor defaults: k=5, identity morphism, euk2
# distance and batch_size=None (every training point is a candidate).
knn = KNN(X_train=X, y_train=y)

In [11]:
# Reference path: the explicit Python loop over every query/candidate pair.
knn.forward(Y, fast=False)

tensor([1., 2., 4., 1., 1., 5., 1., 3., 1., 5., 1., 2., 1., 4., 1., 5., 1., 3.,
        1., 5.])

In [12]:
# Vectorised path: builds the full distance matrix with dist_mat and sorts each row.
knn.forward(Y, fast=True)

tensor([1., 2., 4., 1., 1., 5., 1., 3., 1., 5., 1., 2., 1., 4., 1., 5., 1., 3.,
        1., 5.])