In [None]:
import numpy as np
import matplotlib.pyplot as plt


class KnnClassifier:
    """k-nearest-neighbours classifier over a stored training set.

    "Training" is memorisation: the samples and labels are kept as given and
    every query is compared against all stored samples.
    """

    def __init__(self, train_images, train_labels):
        """Store the training data.

        Args:
            train_images: 2-D array with one training sample per row
                (presumably flattened images -- TODO confirm against the data files).
            train_labels: 1-D array of non-negative integer labels aligned with the
                rows of ``train_images`` (``np.bincount`` in ``classify_image``
                requires non-negative ints).
        """
        self.train_images = train_images
        self.train_labels = train_labels

    def classify_image(self, test_image, num_neighbors=3, metric='l2'):
        """Predict the label of one sample by majority vote of its nearest neighbours.

        Args:
            test_image: 1-D array with the same length as a row of ``train_images``.
            num_neighbors: number k of closest training samples that vote; must be >= 1.
            metric: ``'l2'`` (Euclidean) or ``'l1'`` (Manhattan) distance.

        Returns:
            The most frequent label among the k nearest neighbours. Ties in the
            vote go to the smallest label, because ``np.argmax`` returns the
            first maximum.

        Raises:
            ValueError: if ``metric`` is not ``'l1'``/``'l2'`` or ``num_neighbors`` < 1.
        """
        if num_neighbors < 1:
            raise ValueError(f"num_neighbors must be >= 1, got {num_neighbors}")
        # Promote the query to float so unsigned-integer pixel data (e.g. uint8)
        # cannot wrap around during the subtraction.
        diff = self.train_images - np.asarray(test_image, dtype=np.float64)
        if metric == 'l2':
            distances = np.sqrt(np.sum(diff ** 2, axis=1))
        elif metric == 'l1':
            distances = np.sum(np.abs(diff), axis=1)
        else:
            # Reject unknown names instead of silently falling back to L1.
            raise ValueError(f"unknown metric {metric!r}; expected 'l1' or 'l2'")
        # A stable sort keeps equally distant neighbours in training-set order,
        # so the chosen neighbours (and hence predictions) are deterministic.
        nearest = np.argsort(distances, kind='stable')[:num_neighbors]
        votes = np.bincount(self.train_labels[nearest])
        return np.argmax(votes)


# Load the train/test splits; labels are read as integers because
# KnnClassifier votes with np.bincount.
train_images = np.loadtxt('data/train_images.txt')
train_labels = np.loadtxt('data/train_labels.txt', dtype=int)
test_images = np.loadtxt('data/test_images.txt')
test_labels = np.loadtxt('data/test_labels.txt', dtype=int)

# Baseline: 3-NN with L2 distance, the settings encoded in the output file name.
# They are passed explicitly so the name cannot drift from the method defaults.
classifier = KnnClassifier(train_images, train_labels)
predictions = [
    classifier.classify_image(image, num_neighbors=3, metric='l2')
    for image in test_images
]
predicted_labels = np.array(predictions)
# fmt='%d' stores labels as integers instead of savetxt's default '%.18e' floats.
np.savetxt('predictii_3nn_l2_mnist.txt', predicted_labels, fmt='%d')
classification_ok = np.sum(predicted_labels == test_labels)
total = len(test_labels)
print(f"Accuracy is {classification_ok / total * 100}%")

In [None]:
def accuracy(type_of_metric, neighbor_counts=(1, 3, 5, 7, 9), knn=None, images=None, labels=None):
    """Measure k-NN test accuracy for several neighbour counts and save it to disk.

    For every k in ``neighbor_counts``, each test image is classified with
    ``knn.classify_image(image, k, type_of_metric)``, and the fraction of correct
    predictions is recorded. The fractions are written to
    ``acuratete_{type_of_metric}.txt`` (one value per line, in the order of
    ``neighbor_counts``) and also returned.

    Args:
        type_of_metric: metric name forwarded to ``classify_image``
            (the notebook uses ``'l1'`` and ``'l2'``).
        neighbor_counts: values of k to evaluate; defaults to 1, 3, 5, 7, 9.
        knn: object exposing ``classify_image``; defaults to the notebook-level
            ``classifier``.
        images: test images; defaults to the notebook-level ``test_images``.
        labels: ground-truth labels; defaults to the notebook-level ``test_labels``.

    Returns:
        numpy array of accuracies in [0, 1], one per entry of ``neighbor_counts``.

    Raises:
        ValueError: if ``images`` and ``labels`` differ in length or are empty.
    """
    # Fall back to notebook globals so the original one-argument call keeps working.
    knn = classifier if knn is None else knn
    images = test_images if images is None else images
    labels = test_labels if labels is None else labels

    total = len(images)
    if total != len(labels):
        raise ValueError(f"{total} images but {len(labels)} labels")
    if total == 0:
        raise ValueError("cannot compute accuracy on an empty test set")

    acc = []
    for k in neighbor_counts:
        correct = sum(
            knn.classify_image(image, k, type_of_metric) == label
            for image, label in zip(images, labels)
        )
        acc.append(correct / total)
    acc = np.array(acc)
    np.savetxt(f"acuratete_{type_of_metric}.txt", acc)
    return acc


# x-axis values: the neighbour counts `accuracy` evaluates by default, in file order.
# Without them the curves would be drawn at x = 0..4, not at the real k values.
NEIGHBOR_COUNTS = [1, 3, 5, 7, 9]

accuracy('l2')
l2_gr = np.loadtxt('acuratete_l2.txt')
accuracy('l1')
l1_gr = np.loadtxt('acuratete_l1.txt')

fig, ax = plt.subplots()
ax.plot(NEIGHBOR_COUNTS, l2_gr, marker='o', label='L2')
ax.plot(NEIGHBOR_COUNTS, l1_gr, marker='o', label='L1')
ax.set_xticks(NEIGHBOR_COUNTS)
ax.set_xlabel('number of neighbors')
ax.set_ylabel('accuracy')
ax.set_title('k-NN test accuracy vs. number of neighbors')
ax.legend()
plt.show()
