Nearest neighbor clustering, in PyTorch

In [2]:
import torchvision
import torch
from PIL import Image

In [3]:
size = (128, 128)
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size),
    torchvision.transforms.ToTensor()
])

In [19]:
train_dataset = torchvision.datasets.Flowers102("./flowers", "train", transform=transform, download=True)
test_dataset = torchvision.datasets.Flowers102("./flowers", "test", transform=transform, download=True)

In [5]:
def visualize_image(image):
    image = image.permute(1, 2, 0) * 255
    image = image.to(torch.uint8).numpy()
    image = Image.fromarray(image)
    image.show()

In [7]:
visualize_image(train_dataset[0][0])

In [10]:
# let's grab the first 20 images, 10 of which are in class 0 and 10 of which are in class 1
class_01 = list(train_dataset)[:20]

In [14]:
def nn_classifier(x):
    """Computes the distances between a given observation and the training set,
    and then returns the class of the closest observation.
    """
    # calculate RMSE between x and each image in the training set.
    # we square each element of the difference, sum them, take the mean,
    # and then take the square root
    distance_label_tuples = [
        ((x - img).pow(2).sum().sqrt(), label) for img, label in class_01
    ]
    # get the index of the smallest distance
    return min(distance_label_tuples)[1]

In [15]:
nn_classifier(train_dataset[0][1])

1

In [20]:
num_samples = 20
accuracy = sum(nn_classifier(img) == label for img, label in list(test_dataset)[:num_samples]) / num_samples

In [21]:
accuracy

0.4

## K-nearest-neighbors classification/regression

Find $k$ closest values in dataset $D$ and return the most common label (classification) or average value (regression).

In [25]:
def knn_classifier(x, k=3):
    """Computes the distances between a given observation and the training set,
    and then returns the class of the closest "k" observations.
    """
    # calculate RMSE between x and each image in the training set.
    # we square each element of the difference, sum them, take the mean,
    # and then take the square root
    distance_label_tuples = [
        ((x - img).pow(2).sum().sqrt(), label) for img, label in class_01
    ]
    # get the k nearest distances
    k_smallest = sorted(distance_label_tuples)[:k]

    # return the most common label across the k nearest neighbors.
    mean_label = sum(label for _, label in k_smallest) / k
    print(f"Mean label: {mean_label}")
    return round(mean_label) # rounds to nearest int, so 0 or 1

In [27]:
knn_classifier(test_dataset[0][1], k=5)

Mean label: 0.8


1