In [2]:
import torch
import torchvision
import torchvision.transforms as transforms

# Define transforms for data augmentation and normalization
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Load CIFAR-10 dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Create DataLoader for batch processing
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1000, shuffle=False, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=1000, shuffle=False, num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 44877109.95it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [5]:
import torch.nn as nn
import torchvision.models as models
import numpy as np

# Load pretrained ResNet model and modify it to act as a feature extractor
resnet18 = models.resnet18(pretrained=True)
resnet18 = nn.Sequential(*list(resnet18.children())[:-1])  # Remove the final classification layer

# Function to extract features
def extract_features(dataloader):
    resnet18.eval()  # Set model to evaluation mode
    features = []
    labels = []
    with torch.no_grad():
        for inputs, targets in dataloader:
            outputs = resnet18(inputs).squeeze()
            features.append(outputs.cpu().numpy())
            labels.append(targets.numpy())
    return np.vstack(features), np.hstack(labels)

# Extract features from train and test set
train_features, train_labels = extract_features(trainloader)
test_features, test_labels = extract_features(testloader)


In [6]:
from sklearn.cluster import KMeans
import numpy as np

# Apply K-means clustering to training set
kmeans = KMeans(n_clusters=10, random_state=42)

train_clusters = kmeans.fit_predict(train_features)



In [7]:

# Assign each test image to its nearest cluster
test_clusters = kmeans.predict(test_features)


In [8]:
from sklearn.metrics.pairwise import cosine_similarity

def get_top_k_matches(test_feature, cluster_indices, k=50):
    # Get the features of images belonging to the same cluster
    cluster_features = train_features[cluster_indices]
    
    # Compute cosine similarity between the test image and cluster images
    similarities = cosine_similarity(test_feature.reshape(1, -1), cluster_features).flatten()
    
    # Get the top k most similar images
    top_k_indices = np.argsort(similarities)[::-1][:k]
    return cluster_indices[top_k_indices]

# For each test image, find the top 50 matches
top_k_matches = []
for i, test_feature in enumerate(test_features):
    # Find the training images belonging to the same cluster
    cluster_indices = np.where(train_clusters == test_clusters[i])[0]
    
    # Get the top 50 matches based on cosine similarity
    top_k_matches.append(get_top_k_matches(test_feature, cluster_indices))


In [9]:
def precision_at_k(true_label, top_k_labels, k):
    top_k = top_k_labels[:k]
    correct = np.sum(top_k == true_label)
    return correct / k

def mean_average_precision(true_label, top_k_labels):
    # Calculate precision at each rank and then compute average precision
    precisions = []
    correct = 0
    for i, label in enumerate(top_k_labels):
        if label == true_label:
            correct += 1
            precisions.append(correct / (i + 1))
    return np.mean(precisions) if precisions else 0

# Evaluate for all test images
precision_10 = []
precision_50 = []
mean_ap = []
for i, matches in enumerate(top_k_matches):
    true_label = test_labels[i]
    matched_labels = train_labels[matches]
    
    precision_10.append(precision_at_k(true_label, matched_labels, 10))
    precision_50.append(precision_at_k(true_label, matched_labels, 50))
    mean_ap.append(mean_average_precision(true_label, matched_labels))

# Report final metrics
print(f'Mean Precision@10: {np.mean(precision_10):.4f}')
print(f'Mean Precision@50: {np.mean(precision_50):.4f}')
print(f'Mean Average Precision: {np.mean(mean_ap):.4f}')


Mean Precision@10: 0.4224
Mean Precision@50: 0.3724
Mean Average Precision: 0.4547
