In [1]:
import torchvision
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import torch
import numpy as np

In [2]:
# Per-image preprocessing: tensor conversion, normalisation, then flattening so that every
# image becomes a single row (784 pixels) of the feature matrices used by the k-NN models.
mnist_transforms = [
    transforms.ToTensor(),
    # Mean/std values commonly used for MNIST pixels (as in the PyTorch MNIST example).
    transforms.Normalize((0.1307,), (0.3081,)),
    # Flatten each image tensor to a 1-D vector; one Lambda wrapper is sufficient.
    transforms.Lambda(torch.flatten),
]
transform = transforms.Compose(mnist_transforms)

In [3]:
# The loaders are only iterated once to materialise NumPy arrays, so:
# - shuffle=False keeps row order deterministic. HNSW graph construction depends on insertion
#   order, so an unseeded shuffle makes the faiss accuracy/timings non-reproducible, and row
#   order is irrelevant for k-NN anyway.
# - pin_memory=False because batches are converted to NumPy on the CPU, never copied to a GPU.
train_kwargs = {"num_workers": 1, "pin_memory": False, "shuffle": False, "batch_size": 128}
test_kwargs = {"num_workers": 1, "pin_memory": False, "shuffle": False, "batch_size": 256}

data_path = "./data"
train_dataset = MNIST(data_path, train=True, download=True, transform=transform)
test_dataset = MNIST(data_path, train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, **train_kwargs)
test_loader = DataLoader(test_dataset, **test_kwargs)

In [4]:
def dataloader_to_numpy(loader, feature_extractor=None, dtype=np.float64):
    """Materialise every ``(inputs, labels)`` batch of a loader as NumPy arrays.

    Args:
        loader: Iterable yielding ``(xs, ys)`` pairs of tensors, e.g. a
            ``torch.utils.data.DataLoader``.
        feature_extractor: Optional callable (e.g. a ``torch.nn.Module``) applied to each
            ``xs`` batch; its output is stacked instead of the raw inputs. It is run under
            ``torch.no_grad()``.
        dtype: dtype of the returned feature matrix (default ``np.float64``).

    Returns:
        ``(X, y)``, where ``X`` stacks all batches row-wise and ``y`` is a 1-D array of
        labels in loader order.

    Raises:
        ValueError: if the loader yields no batches.
    """
    feature_batches = []
    labels = []
    # Without no_grad, outputs of a trainable extractor require grad and `.numpy()` raises.
    with torch.no_grad():
        for xs, ys in loader:
            if feature_extractor is not None:
                xs = feature_extractor(xs)
            # detach/cpu make this work for tensors that track grad or live on a GPU.
            feature_batches.append(xs.detach().cpu().numpy())
            labels.extend(ys.tolist())
    if not feature_batches:
        raise ValueError("loader yielded no batches; nothing to convert")
    # vstack already returns a fresh array, so avoid a second copy when dtype already matches.
    return np.vstack(feature_batches).astype(dtype, copy=False), np.array(labels)

In [5]:
# Materialise both splits as (features, labels) NumPy arrays, train first, then test.
(X_train, y_train), (X_test, y_test) = (
    dataloader_to_numpy(split_loader) for split_loader in (train_loader, test_loader)
)

In [6]:
# Sanity check: features are (n_samples, n_features), labels are (n_samples,).
tuple(arr.shape for arr in (X_train, y_train))

((60000, 784), (60000,))

### Baseline: scikit-learn's exact k-NN (slow, for reference)


In [7]:
from sklearn.neighbors import KNeighborsClassifier
import time

In [8]:
n_neighbors = 1
weights = "uniform"

# perf_counter is monotonic and high-resolution; time.time() can jump, so it is unsuitable for benchmarks.
train_begin = time.perf_counter()
knn = KNeighborsClassifier(n_neighbors=n_neighbors, weights=weights, n_jobs=-1)
knn.fit(X_train, y_train)
train_end = time.perf_counter()

# test
inference_begin = time.perf_counter()
preds = knn.predict(X_test)
inference_end = time.perf_counter()

num_corrects_test = np.sum(preds == y_test)
test_acc = num_corrects_test / len(y_test) * 100

# Log the parameters actually used (not hard-coded literals) so the report cannot drift from the model.
print(f"K:{n_neighbors}, d:{weights}, test acc:{num_corrects_test}/{len(y_test)} ({test_acc:.2f}%)")

K:1, d:uniform, test acc:9691/10000 (96.91%)


In [9]:
print("training: {:.3f}sec, inference: {:.3f}sec".format(train_end - train_begin, inference_end - inference_begin))

training: 0.041sec, inference: 16.484sec


### Approximate k-NN with faiss (HNSW index)

In [10]:
import faiss

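The index type below was chosen by following the faiss guidelines for choosing an index:
https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index

HNSW is an approximate search method, so expect slightly lower accuracy than the exact scikit-learn baseline in exchange for much faster queries.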

In [11]:
d = X_train.shape[1]  # feature dimension (784 for flattened MNIST)
k = 1                 # neighbours retrieved per query
M = 32                # HNSW: number of graph links per node; higher = better recall, more memory
EF_CONSTRUCTION = 40  # faiss default candidate-list size while building the graph
EF_SEARCH = 16        # faiss default candidate-list size at query time; raise to trade speed for recall

# perf_counter is monotonic and high-resolution, unlike time.time().
train_begin = time.perf_counter()
# For exact (brute-force) search, use faiss.IndexFlatL2(d) instead of the HNSW index.
index = faiss.IndexHNSWFlat(d, M)
index.hnsw.efConstruction = EF_CONSTRUCTION  # only takes effect if set before add()
index.hnsw.efSearch = EF_SEARCH
# faiss requires C-contiguous float32 input; the conversion is included in the timed region.
index.add(np.ascontiguousarray(X_train, dtype=np.float32))
train_end = time.perf_counter()

In [12]:
# perf_counter is monotonic and high-resolution, unlike time.time().
inference_begin = time.perf_counter()
# D: squared L2 distances; I: row indices into X_train. Both have shape (n_test, k).
D, I = index.search(np.ascontiguousarray(X_test, dtype=np.float32), k)
inference_end = time.perf_counter()

In [13]:
# faiss pads results with -1 when fewer than k neighbours are found; indexing y_train with -1
# would silently return the last training label, so fail loudly instead.
assert (I >= 0).all(), "faiss returned missing neighbours (-1); increase efSearch"
neighbour_labels = y_train[I]  # (n_test, k) labels of the retrieved neighbours
# Majority vote over the k neighbours (ties go to the smallest label); for k == 1 this is
# exactly the label of the single nearest neighbour.
predicted_labels = np.array([np.bincount(row).argmax() for row in neighbour_labels])

In [14]:
# Test accuracy in percent.
(predicted_labels == y_test).mean() * 100

96.58

In [15]:
print("training: %.3fsec, inference: %.3fsec" % (train_end - train_begin, inference_end - inference_begin))

training: 4.880sec, inference: 0.528sec
