In [None]:
import tensorflow as tf
from scipy.spatial.distance import euclidean
from tensorflow.keras import datasets, layers, models
import numpy as np


# Load CIFAR-10 through Keras. Keras documents the images as uint8 arrays
# of shape (N, 32, 32, 3) and the labels as (N, 1) integer arrays in 0-9.
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

# Scale pixels to [0, 1]. Cast to float32 *before* dividing: uint8 / 255.0
# would promote to float64, doubling memory for no benefit, because Keras
# trains in float32 anyway.
train_images = train_images.astype("float32") / 255.0
test_images = test_images.astype("float32") / 255.0


def create_resnet18(input_shape=(32, 32, 3), num_classes=10):
    """Build a small VGG-style convolutional classifier.

    Despite its name this is NOT a ResNet-18. It is a plain sequential stack
    of 3x3 ReLU convolutions and 2x2 max-pooling. It has no residual (skip)
    connections and no batch normalisation.

    Args:
        input_shape: Shape of one input image as (height, width, channels).
            Defaults to CIFAR-10's (32, 32, 3). Height and width should be at
            least 16 so the four 2x2 pooling stages leave a non-empty map.
        num_classes: Number of output classes. Defaults to 10.

    Returns:
        An uncompiled Keras ``Sequential`` model. Its last layer is a
        ``num_classes``-way softmax, so ``model.predict`` returns class
        probabilities rather than feature embeddings.
    """
    model = models.Sequential()
    # Declare the input with an explicit Input rather than `input_shape=` on
    # the first layer; newer Keras versions warn about the latter.
    model.add(layers.Input(shape=input_shape))
    model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))
    model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(256, (3, 3), activation='relu', padding='same'))
    model.add(layers.Conv2D(256, (3, 3), activation='relu', padding='same'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Flatten())
    model.add(layers.Dense(512, activation='relu'))
    model.add(layers.Dense(num_classes, activation='softmax'))
    return model


# Fix TensorFlow's global seed before building the model. This makes weight
# initialisation and batch shuffling repeatable across Restart & Run All
# (some GPU kernels may still be nondeterministic).
SEED = 42
tf.random.set_seed(SEED)

model = create_resnet18()

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# NOTE(review): the test split doubles as validation data here, so the
# reported val_* metrics are really test metrics. Nothing below tunes on
# them, but carve out a separate validation split before doing any model
# selection.
history = model.fit(train_images, train_labels, epochs=10, batch_size=64,
                    validation_data=(test_images, test_labels))

# The network ends in a 10-way softmax (see create_resnet18). These
# "embeddings" are therefore per-class probability vectors of shape (N, 10),
# not penultimate-layer features.
train_embeddings = model.predict(train_images)


In [None]:
# Run the trained classifier over the test split. The bare expression at the
# end displays the shape of the resulting array.
test_embeddings = model.predict(x=test_images)
np.shape(test_embeddings)

In [None]:
# Pair every model-output vector with its ground-truth label so the kNN
# helpers can iterate over (vector, label) tuples. zip() would silently
# truncate on a length mismatch, so check the lengths explicitly first.
assert len(train_embeddings) == len(train_labels), "train embeddings/labels length mismatch"
assert len(test_embeddings) == len(test_labels), "test embeddings/labels length mismatch"
resnet_train_data = list(zip(train_embeddings, train_labels))
resnet_test_data = list(zip(test_embeddings, test_labels))

In [None]:
import matplotlib.pyplot as plt
def find_knn_and_plot_image(k_value: int = 10, train_data=None, test_data=None, images=None):
    """Plot the k nearest training images for one query and score them.

    Neighbours are ranked by Euclidean distance between the vectors in
    ``train_data`` and the query vector. Equal distances keep training-set
    order.

    Args:
        k_value: Number of neighbours to retrieve and plot. Must be >= 1.
        train_data: Sequence of ``(vector, label)`` pairs. ``None`` (the
            default) uses the notebook's current ``resnet_train_data``.
        test_data: A single ``(vector, label)`` query pair. ``None`` (the
            default) uses the current ``resnet_test_data[0]``.
        images: Images indexed in parallel with ``train_data``, used only for
            plotting. ``None`` (the default) uses the global ``train_images``.

    Returns:
        Fraction of the ``k_value`` neighbours whose label equals the query
        label. The denominator is ``k_value`` even if ``train_data`` has
        fewer entries.

    Raises:
        ValueError: If ``k_value`` < 1 or ``train_data`` is empty.
    """
    # Resolve defaults at call time. A signature default of
    # `resnet_train_data` would freeze whatever list existed when this cell
    # last ran.
    if train_data is None:
        train_data = resnet_train_data
    if test_data is None:
        test_data = resnet_test_data[0]
    if images is None:
        images = train_images
    if k_value < 1:
        raise ValueError(f"k_value must be >= 1, got {k_value}")
    if len(train_data) == 0:
        raise ValueError("train_data must not be empty")

    x_test, y_test = test_data
    train_matrix = np.stack([np.asarray(x_train) for x_train, _ in train_data])
    # One vectorised pass instead of a Python loop over every training row.
    distances = np.linalg.norm(train_matrix - np.asarray(x_test), axis=1)
    # A stable sort keeps training-set order among equal distances.
    nearest = np.argsort(distances, kind="stable")[:k_value]

    # squeeze=False keeps `axes` 2-D even when k_value == 1. Otherwise
    # plt.subplots returns a bare Axes that cannot be indexed.
    fig, axes = plt.subplots(1, k_value, figsize=(15, 3), squeeze=False)
    correct = 0
    for ax, index in zip(axes[0], nearest):
        index = int(index)
        if train_data[index][1] == y_test:
            correct += 1
        ax.imshow(images[index])
        ax.set_title(f"Index: {index}")
        ax.axis('off')
    plt.show()
    return correct / k_value

In [None]:
# Inspect one test image: print its label, show it, then plot its 10 nearest
# training neighbours. The cell displays the returned score.
index = 91
print(test_labels[index])
plt.imshow(test_images[index])
find_knn_and_plot_image(
    k_value=10,
    train_data=resnet_train_data,
    test_data=resnet_test_data[index],
)

In [None]:
def find_knn(k_value: int = 10, train_data=None, test_data=None):
    """Retrieve the k nearest training points for every test point.

    Distances are Euclidean between the stored vectors. Equal distances keep
    training-set order.

    Args:
        k_value: Neighbours retrieved per test point. Must be >= 1.
        train_data: Non-empty sequence of ``(vector, label)`` pairs. ``None``
            (the default) uses the notebook's current ``resnet_train_data``.
        test_data: Non-empty sized sequence of ``(vector, label)`` pairs.
            ``None`` (the default) uses the current ``resnet_test_data``.

    Returns:
        A tuple ``(results, accuracy)``:

        - ``results[t]``: list of ``(train_label, train_index)`` pairs for
          test point ``t``, closest first.
        - ``accuracy``: ``correct / (k_value * len(test_data))``, where
          ``correct`` counts retrieved neighbours whose label equals their
          query's label. This is a precision@k, not majority-vote kNN
          accuracy.

    Raises:
        ValueError: If ``k_value`` < 1, or either dataset is empty.
    """
    # Resolve defaults at call time so a re-run data cell is always picked up.
    if train_data is None:
        train_data = resnet_train_data
    if test_data is None:
        test_data = resnet_test_data
    if k_value < 1:
        raise ValueError(f"k_value must be >= 1, got {k_value}")
    if len(train_data) == 0:
        raise ValueError("train_data must not be empty")
    if len(test_data) == 0:
        raise ValueError("test_data must not be empty")

    try:
        from tqdm.auto import tqdm  # optional progress bar
    except ImportError:
        def tqdm(iterable, **_kwargs):
            return iterable

    # Build the (n_train, dim) matrix once, outside the per-query loop.
    train_matrix = np.stack([np.asarray(x_train) for x_train, _ in train_data])

    results = []
    correct = 0
    for x_test, y_test in tqdm(test_data):
        distances = np.linalg.norm(train_matrix - np.asarray(x_test), axis=1)
        # A stable sort matches sorted()'s tie order (lower train index first).
        nearest = np.argsort(distances, kind="stable")[:k_value]
        neighbours = []
        for index in nearest:
            index = int(index)
            label = train_data[index][1]
            if label == y_test:
                correct += 1
            neighbours.append((label, index))
        results.append(neighbours)
    return results, correct / (k_value * len(test_data))

In [None]:
results, accuracy = find_knn(k_value=10)  # 10 neighbours per test image

In [None]:
# Share of all retrieved neighbours (10 per test image) whose label matches
# the query's label, i.e. correct / (k * n_test) from find_knn.
# NOTE(review): this is precision@k, not majority-vote kNN classification accuracy.
accuracy