In [1]:
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Define transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Download CIFAR-10
cifar10_train = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
cifar10_test = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Download CIFAR-100
cifar100_train = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
cifar100_test = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)

# Dataloaders (optional, if needed for training/testing)
train_loader = DataLoader(cifar10_train, batch_size=64, shuffle=True)
test_loader = DataLoader(cifar10_test, batch_size=64, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:05<00:00, 29493115.08it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:08<00:00, 20734464.76it/s]


Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified


In [2]:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
from sklearn.cluster import KMeans
from tqdm import tqdm

# Load pre-trained ResNet model
resnet = models.resnet18(pretrained=True)
resnet.fc = torch.nn.Identity()  # Remove the final classification layer to get features

# Ensure model is in evaluation mode
resnet.eval()

# CIFAR transformations
transform = transforms.Compose([
    transforms.Resize(224),  # ResNet expects 224x224 images
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Function to extract features
def extract_features(dataset, model, batch_size=64):
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    features = []
    with torch.no_grad():
        for inputs, _ in tqdm(dataloader):
            outputs = model(inputs)
            features.append(outputs.cpu().numpy())
    return np.vstack(features)

# Extract features for training and test set
#cifar10_train = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
#cifar10_test = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)

train_features = extract_features(cifar10_train, resnet)
test_features = extract_features(cifar10_test, resnet)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 212MB/s]
100%|██████████| 782/782 [00:44<00:00, 17.39it/s]
100%|██████████| 157/157 [00:09<00:00, 16.89it/s]


In [3]:
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import normalize

# Set number of clusters
K = 10  # You can experiment with different values

# Perform K-means clustering
kmeans = KMeans(n_clusters=K, random_state=0).fit(train_features)

# Predict clusters for test set
test_clusters = kmeans.predict(test_features)

# Get cluster centers
cluster_centers = kmeans.cluster_centers_

# Normalize features for cosine similarity calculation
train_features_normalized = normalize(train_features)
test_features_normalized = normalize(test_features)
cluster_centers_normalized = normalize(cluster_centers)

# Function to retrieve top-N similar images for a query image
def retrieve_similar_images(query_index, top_n=50):
    query_cluster = test_clusters[query_index]
    cluster_members = np.where(kmeans.labels_ == query_cluster)[0]  # Find images in the same cluster
    similarities = cosine_similarity([test_features_normalized[query_index]], train_features_normalized[cluster_members])[0]
    top_n_indices = cluster_members[np.argsort(similarities)[-top_n:]]
    return top_n_indices, similarities[np.argsort(similarities)[-top_n:]]



In [10]:
# Example query
query_index = 0  # Index of a test image
top_n_indices, similarities = retrieve_similar_images(query_index, 50)

print(f"Top {len(top_n_indices)} similar images to query index {query_index}: {top_n_indices}")

Top 50 similar images to query index 0: [11787 18908 39549  8001   241 12849 47277 17320  5450   821 30938 26558
 32371  8093 46448 36613 17776 12256 20107 28058 48170  6434 20963 40546
 12424 22407  2447  6790 46403 46229 41908 14960  6851 39861 29549 47493
 18413 20634 43476 45949 38820 24891 20641 48683 21275  7286 42522 19609
 35694 39972]


In [12]:
correct = cifar10_test[query_index][1]
cnt = 0
for idx in top_n_indices:
    if cifar10_train[idx][1] == correct:
        cnt += 1
print(cnt)

14
