In [None]:

!pip install datasets

In [None]:
from huggingface_hub import HfApi, hf_hub_url
from torchvision import models, transforms
from PIL import Image
import torch
import pandas as pd
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt

In [None]:
from datasets import load_dataset

# Load the 'full' configuration of the shoe classification dataset.
dataset = load_dataset(
    path="keremberke/shoe-classification",
    name='full',
)


In [None]:

def show_images(dataset, num_images=5):
    """Display the leading images of ``dataset["train"]`` side by side in one row.

    Args:
        dataset: Mapping with a ``"train"`` entry that supports ``len()`` and
            integer indexing; every item is a mapping with an ``"image"``
            entry that ``imshow`` can draw.
        num_images: Maximum number of images to display. Fewer are drawn when
            the split holds fewer items.
    """
    train_split = dataset["train"]
    count = min(num_images, len(train_split))
    if count <= 0:
        # Nothing to draw; a subplot grid needs at least one column.
        return

    # squeeze=False keeps `axes` two-dimensional even for a single image;
    # otherwise plt.subplots returns a bare Axes object that cannot be indexed.
    _, axes = plt.subplots(1, count, figsize=(15, 3), squeeze=False)
    for column in range(count):
        axis = axes[0][column]
        axis.imshow(train_split[column]["image"])
        axis.axis('off')
    plt.show()

show_images(dataset)


In [None]:

def preprocess_image(image):
    """Turn one image into a normalized one-image batch tensor.

    The image is resized to 224x224, converted to a tensor, normalized per
    channel and given a leading batch dimension via ``unsqueeze(0)``.

    Args:
        image: Image to preprocess. PIL images in a mode other than RGB
            (grayscale, palette, RGBA, ...) are converted to RGB first.

    Returns:
        The preprocessed tensor with an extra leading dimension of size 1.
    """
    # Normalize below is configured with three per-channel statistics, so a
    # one- or four-channel image would fail there. Convert such PIL images
    # up front; other input types are passed through untouched.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # NOTE(review): these look like the usual ImageNet channel statistics
        # expected by torchvision's pretrained models - confirm.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return preprocess(image).unsqueeze(0)

In [None]:

# Pretrained AlexNet, switched to evaluation mode (eval() returns the module
# itself, so the call can be chained).
model = models.alexnet(pretrained=True).eval()
model


In [None]:

def get_embedding(image, model):
    """Return the flattened output of ``model`` for a single image.

    Args:
        image: Image accepted by ``preprocess_image``.
        model: Callable torch module applied to the one-image batch.

    Returns:
        A one-dimensional NumPy array holding the model output.
    """
    batch = preprocess_image(image)

    # Feed the input on the device that holds the model's weights, so the
    # function also works when the model has been moved off the CPU.
    if hasattr(model, "parameters"):
        first_param = next(iter(model.parameters()), None)
        if first_param is not None:
            batch = batch.to(first_param.device)

    with torch.no_grad():  # inference only, no autograd bookkeeping needed
        output = model(batch)

    # .numpy() only works on CPU tensors; .cpu() leaves CPU tensors as they are.
    return output.cpu().numpy().flatten()

In [None]:

# Swap classifier[6] for an identity so the model returns the activations that
# used to feed that layer; those activations serve as the image embedding.
model.classifier[6] = torch.nn.Identity()

# One embedding per training image, in dataset order.
embeddings = [
    get_embedding(dataset["train"][row]["image"], model)
    for row in range(len(dataset["train"]))
]

In [None]:
# Sanity check: inspect the shape of a single embedding vector.
embeddings[0].shape

In [None]:

# Neighbour index over the AlexNet embeddings. Six neighbours are requested
# because retrieve_images drops the query item from the result.
nn_model = NearestNeighbors(algorithm='ball_tree', n_neighbors=6).fit(embeddings)
nn_model

In [None]:

def retrieve_images(index, nn_model, dataset, vectors=None):
    """Return the images of the nearest neighbours of one training item.

    Args:
        index: Position of the query item in ``dataset["train"]``; converted
            with ``int``.
        nn_model: Fitted neighbour index exposing ``kneighbors``.
        dataset: Mapping whose ``"train"`` entry is indexable and yields
            mappings with an ``"image"`` entry.
        vectors: Optional sequence of embeddings aligned with
            ``dataset["train"]``. When omitted, the notebook-level
            ``embeddings`` list is used, which keeps existing calls working.

    Returns:
        List of neighbour images ordered by increasing distance, with the
        query item itself left out.
    """
    index = int(index)
    pool = embeddings if vectors is None else vectors

    distances, indices = nn_model.kneighbors([pool[index]])

    # Pair each neighbour index with its distance, skipping the query itself,
    # and order the remaining pairs from closest to farthest.
    neighbours = sorted(
        (
            (int(neighbour), distance)
            for neighbour, distance in zip(indices[0], distances[0])
            if int(neighbour) != index
        ),
        key=lambda pair: pair[1],
    )

    return [dataset["train"][neighbour]["image"] for neighbour, _ in neighbours]

In [None]:
# Query the AlexNet-based index with training image 94.
test_index = 94
retrieved_images = retrieve_images(index=test_index, nn_model=nn_model, dataset=dataset)

# Draw the query image on its own figure.
plt.imshow(dataset["train"][test_index]["image"])
plt.axis('off')
plt.title("Test Image")
plt.show()

# show_images reads dataset["train"][i]["image"], so wrap the results accordingly.
show_images({"train": [dict(image=neighbour) for neighbour in retrieved_images]})

In [None]:
# Query the AlexNet-based index with training image 65.
test_index = 65
retrieved_images = retrieve_images(index=test_index, nn_model=nn_model, dataset=dataset)

# Draw the query image on its own figure.
plt.imshow(dataset["train"][test_index]["image"])
plt.axis('off')
plt.title("Test Image")
plt.show()

# show_images reads dataset["train"][i]["image"], so wrap the results accordingly.
show_images({"train": [dict(image=neighbour) for neighbour in retrieved_images]})

In [None]:
# Query the AlexNet-based index with training image 358.
test_index = 358
retrieved_images = retrieve_images(index=test_index, nn_model=nn_model, dataset=dataset)

# Draw the query image on its own figure.
plt.imshow(dataset["train"][test_index]["image"])
plt.axis('off')
plt.title("Test Image")
plt.show()

# show_images reads dataset["train"][i]["image"], so wrap the results accordingly.
show_images({"train": [dict(image=neighbour) for neighbour in retrieved_images]})

# Deeper model: ResNet-152 embeddings

In [None]:
def get_embedding(image, model):
    """Return the flattened output of ``model`` for a single image.

    Args:
        image: Image accepted by ``preprocess_image``.
        model: Callable torch module applied to the one-image batch.

    Returns:
        A one-dimensional NumPy array holding the model output.
    """
    batch = preprocess_image(image)

    # The output is copied back with .cpu() below; for that to be reachable
    # with a model living on another device, the input has to be moved to
    # the device of the model's weights first.
    if hasattr(model, "parameters"):
        first_param = next(iter(model.parameters()), None)
        if first_param is not None:
            batch = batch.to(first_param.device)

    with torch.no_grad():  # inference only, no autograd bookkeeping needed
        output = model(batch)

    return output.cpu().numpy().flatten()

In [None]:

# Pretrained ResNet-152, switched to evaluation mode (eval() returns the
# module itself, so the call can be chained).
model2 = models.resnet152(pretrained=True).eval()
model2


In [None]:
# Swap the `fc` layer for an identity so the model returns the activations
# that used to feed it; those activations serve as the image embedding.
model2.fc = torch.nn.Identity()

# One embedding per training image, in dataset order.
embeddings2 = [
    get_embedding(dataset["train"][row]["image"], model2)
    for row in range(len(dataset["train"]))
]

In [None]:
# Neighbour index over the ResNet-152 embeddings. Six neighbours are requested
# because retrieve_images2 drops the query item from the result.
nn_model2 = NearestNeighbors(algorithm='ball_tree', n_neighbors=6).fit(embeddings2)
nn_model2

In [None]:

def retrieve_images2(index, nn_model, dataset, vectors=None):
    """Return the images of the nearest neighbours of one training item.

    Args:
        index: Position of the query item in ``dataset["train"]``; converted
            with ``int``.
        nn_model: Fitted neighbour index exposing ``kneighbors``.
        dataset: Mapping whose ``"train"`` entry is indexable and yields
            mappings with an ``"image"`` entry.
        vectors: Optional sequence of embeddings aligned with
            ``dataset["train"]``. When omitted, the notebook-level
            ``embeddings2`` list is used, which keeps existing calls working.

    Returns:
        List of neighbour images ordered by increasing distance, with the
        query item itself left out.
    """
    index = int(index)
    pool = embeddings2 if vectors is None else vectors

    distances, indices = nn_model.kneighbors([pool[index]])

    # Pair each neighbour index with its distance, skipping the query itself,
    # and order the remaining pairs from closest to farthest.
    neighbours = sorted(
        (
            (int(neighbour), distance)
            for neighbour, distance in zip(indices[0], distances[0])
            if int(neighbour) != index
        ),
        key=lambda pair: pair[1],
    )

    return [dataset["train"][neighbour]["image"] for neighbour, _ in neighbours]

In [None]:
# Query the ResNet-152-based index with training image 94.
test_index = 94
retrieved_images = retrieve_images2(index=test_index, nn_model=nn_model2, dataset=dataset)

# Draw the query image on its own figure.
plt.imshow(dataset["train"][test_index]["image"])
plt.axis('off')
plt.title("Test Image")
plt.show()

# show_images reads dataset["train"][i]["image"], so wrap the results accordingly.
show_images({"train": [dict(image=neighbour) for neighbour in retrieved_images]})

In [None]:
# Query the ResNet-152-based index with training image 65.
test_index = 65
retrieved_images = retrieve_images2(index=test_index, nn_model=nn_model2, dataset=dataset)

# Draw the query image on its own figure.
plt.imshow(dataset["train"][test_index]["image"])
plt.axis('off')
plt.title("Test Image")
plt.show()

# show_images reads dataset["train"][i]["image"], so wrap the results accordingly.
show_images({"train": [dict(image=neighbour) for neighbour in retrieved_images]})

In [None]:
# Query the ResNet-152-based index with training image 358.
test_index = 358
retrieved_images = retrieve_images2(index=test_index, nn_model=nn_model2, dataset=dataset)

# Draw the query image on its own figure.
plt.imshow(dataset["train"][test_index]["image"])
plt.axis('off')
plt.title("Test Image")
plt.show()

# show_images reads dataset["train"][i]["image"], so wrap the results accordingly.
show_images({"train": [dict(image=neighbour) for neighbour in retrieved_images]})