In [1]:
import torch
import torchvision
import torchvision.transforms as transforms

# Define transforms for data augmentation and normalization
transform = transforms.Compose([transforms.Resize((224, 224)), 
                                transforms.ToTensor(), 
                                transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
                               ])

# Load CIFAR-10 dataset
trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
num_labels = 100

device = 'cuda'

# Create DataLoader for batch processing
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1000, shuffle=False, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=1000, shuffle=False, num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:03<00:00, 48986823.59it/s]


Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified


In [2]:
import torch.nn as nn
import torchvision.models as models
import numpy as np
from tqdm import tqdm

# Load pretrained ResNet model and modify it to act as a feature extractor
resnet50 = models.resnet50(pretrained=True)
resnet50 = nn.Sequential(*list(resnet50.children())[:-1])  # Remove the final classification layer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet50 = resnet50.to(device)

# Function to extract features
def extract_features(dataloader, dataset_name='Dataset'):
    resnet50.eval()  # Set model to evaluation mode
    features = []
    labels = []
    with torch.no_grad():
        for inputs, targets in tqdm(dataloader, desc=f'Extracting features from {dataset_name}', unit='batch', total=len(dataloader)):
            outputs = resnet50(inputs.to(device)).squeeze()
            features.append(outputs.cpu())
            labels.append(targets)
    return torch.vstack(features), torch.hstack(labels)



# Extract features from train and test set
train_features, train_labels = extract_features(trainloader, 'Train Set')
test_features, test_labels = extract_features(testloader, 'Test Set')

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 160MB/s] 
Extracting features from Train Set: 100%|██████████| 50/50 [01:26<00:00,  1.72s/batch]
Extracting features from Test Set: 100%|██████████| 10/10 [00:18<00:00,  1.89s/batch]


In [3]:
from torch.utils.data import Dataset, DataLoader

batch_size = 125

def sample_negatives(train_features, train_labels, num_negatives = 10):
    class CustomDataset(Dataset):
        def __init__(self, features, labels):
            self.features = features
            self.labels = labels

        def __len__(self):
            return len(self.labels)

        def __getitem__(self, idx):
            return self.features[idx], self.labels[idx]

    dataset = CustomDataset(train_features, train_labels)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    neg_indices = torch.empty((train_features.shape[0], num_negatives), dtype=torch.int)
    
    with tqdm(total=train_features.shape[0], desc='Sampling Negatives') as pbar:
        for idx, (batch_features, batch_labels) in enumerate(dataloader):
            batch_size_current = batch_features.shape[0]  # Get current batch size
            all_indices = torch.arange(batch_size_current)
            
            for i in range(batch_size_current):
                label = batch_labels[i].item()
                neg_mask = batch_labels != label
                
                assert neg_mask.sum() >= 10, "does not has enough negatives"
                
                neg_candidates = all_indices[neg_mask]
                neg_indices[i + idx * batch_size] = neg_candidates[torch.randperm(len(neg_candidates))[:num_negatives]]
                
                pbar.update(1)
    
    return neg_indices

neg_indices = sample_negatives(train_features, train_labels)

Sampling Negatives: 100%|██████████| 50000/50000 [00:03<00:00, 12918.81it/s]


In [4]:
from torch.utils.data import Dataset, DataLoader

class TrainDataset(Dataset):
    def __init__(self, features, neg_indices):
        self.features = features
        self.neg_indices = neg_indices
    
    def __len__(self):
        return len(self.features)
    
    def __getitem__(self, idx):
        return self.features[idx], self.neg_indices[idx]
    
train_dataset = TrainDataset(train_features, neg_indices)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [5]:
import random

class NeuralLSH(nn.Module):
    def __init__(self, input_dim, hash_dim, num_tables, subset_size):
        super(NeuralLSH, self).__init__()
        self.input_dim = input_dim
        self.hash_dim = hash_dim
        self.num_tables = num_tables
        self.subset_size = subset_size
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.powers_of_two = torch.pow(2, torch.arange(subset_size - 1, -1, -1)).float().to(device)
        self.zero = torch.tensor([0], device=self.device)

        self.init_hash_functions()
        self.hyperplanes = nn.Parameter(torch.randn(self.input_dim, self.hash_dim, device=self.device))
        
    def init_hash_functions(self):
        self.hash_functions = torch.tensor([], device=self.device).long()
        indices = list(range(self.hash_dim))
        for _ in range(self.num_tables):
            random.shuffle(indices)
            self.hash_functions = torch.cat((self.hash_functions,
                torch.tensor([indices[:self.subset_size]], device=self.device).long()), dim=0)
    
    def _projection(self, features):
        return torch.mm(features, self.hyperplanes)
    
    def forward(self, features):
        return torch.tanh(self._projection(features))
    
    
    def init_hash_tables(self, train_features):
        train_features = train_features.to(self.device)
        full_hash_codes = self._projection(train_features)
        self.hash_tables = []

        full_hash_values = torch.transpose(((full_hash_codes[:, self.hash_functions] > 0).float() @ self.powers_of_two).int(), 0, 1)
        for table in range(self.num_tables):
            self.hash_tables.append([])
            for hash_val in range(2 ** self.subset_size):
                self.hash_tables[table].append(torch.nonzero(full_hash_values[table] == hash_val).T[0].tolist())      

    def get_corpus_indices(self, features):
        features = features.to(device)
        full_hash_codes = self._projection(features)
        
        full_hash_values = ((full_hash_codes[:, self.hash_functions] > 0).float() @ self.powers_of_two).int()
                
        corpus_indices = []
        for hash_values in tqdm(full_hash_values, desc='Creating Corpus for Test Image', total=len(full_hash_values)):
            indices = set()
            for hash_table, hash_val in zip(self.hash_tables, hash_values):
                indices.update(hash_table[hash_val.item()])
            
            corpus_indices.append(list(indices))
        
        return corpus_indices
        

In [6]:
def loss_func(hash_codes, neg_indices):
    # taking alpha = beta = gamma = 1/3
    term1 = torch.sum(torch.abs(torch.sum(hash_codes, dim=1))) / hash_codes.shape[0]
    
    term2 = torch.sum(torch.abs(torch.abs(hash_codes) - torch.ones(hash_codes.shape[1], device=device))) / hash_codes.shape[0]
        
    negs = torch.transpose(hash_codes[neg_indices], 1, 2)
    term3 = torch.sum(torch.abs(torch.matmul(hash_codes.unsqueeze(1), negs))) / (neg_indices.shape[0] * neg_indices.shape[1])

    return (term1 + term2 + term3) / 3

def train_model(train_dataloader, model, optimizer, epochs=3, device='cuda'):
    model.to(device)
    model.train()
    for epoch in range(epochs):     
        total_loss = 0
        
        for train_data, neg_idx in train_dataloader:
            train_data = train_data.to(device)
            neg_idx = neg_idx.to(device)

            hash_codes = model(train_data)
            loss = loss_func(hash_codes, neg_idx)
            total_loss += loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        
        if (epoch % 5 == 4 or epoch == epochs - 1):
            print(f'Epoch {epoch+1}, Average Loss: {total_loss / len(train_dataloader)}')
        

In [7]:
from sklearn.metrics.pairwise import cosine_similarity

def get_top_k_matches(train_features, test_feature, cluster_indices, k=50):
    # Get the features of images belonging to the same cluster
    cluster_features = train_features[cluster_indices]

    # Compute cosine similarity between the test image and cluster images
    similarities = cosine_similarity(test_feature.reshape(1, -1), cluster_features).flatten()
    
    # Get the top k most similar images
    top_k_indices = np.argsort(similarities)[::-1][:k]
    return cluster_indices[top_k_indices]

In [8]:
def precision_at_k(true_label, top_k_labels, k):
    top_k = top_k_labels[:k]
    correct = torch.sum(top_k == true_label).item()
    return correct / k

def mean_average_precision(true_label, top_k_labels):
    # Calculate precision at each rank and then compute average precision
    precisions = []
    correct = 0
    for i, label in enumerate(top_k_labels):
        if label == true_label:
            correct += 1
            precisions.append(correct / (i + 1))
    return np.mean(precisions) if precisions else 0

In [9]:
def get_top_matches(train_features, test_features, corpus_indices):
    # For each test image, find the top 50 matches
    top_k_matches = []
    for i, test_feature in tqdm(enumerate(test_features), total=len(test_features)):
        cluster_indices = np.array(corpus_indices[i])

        if (len(cluster_indices) == 0):
            top_k_matches.append([])
            continue
        
        # Get the top 50 matches based on cosine similarity
        top_k_matches.append(get_top_k_matches(train_features, test_feature, cluster_indices))

    return top_k_matches

In [10]:
def evaluate(train_labels, test_labels, top_k_matches):
    # Evaluate for all test images
    precision_10 = []
    precision_50 = []
    mean_ap = []
    for i, matches in enumerate(top_k_matches):
    # for i, matches in tqdm(enumerate(top_k_matches), desc='Evaluating Metrics', unit='image', total=len(top_k_matches)):
        true_label = test_labels[i]
        matched_labels = train_labels[matches]
        
        precision_10.append(precision_at_k(true_label, matched_labels, 10))
        precision_50.append(precision_at_k(true_label, matched_labels, 50))
        mean_ap.append(mean_average_precision(true_label, matched_labels))

    return np.mean(precision_10), np.mean(precision_50), np.mean(mean_ap)

In [11]:
num_features = train_features.shape[1]
hash_dim = 16
num_tables = 10
subset_size = 8

In [12]:
import torch.optim as optim

model = NeuralLSH(num_features, hash_dim, num_tables, subset_size)

optimizer = optim.Adam(model.parameters(), lr=1e-3)

train_model(train_dataloader, model, optimizer, 40)

Epoch 5, Average Loss: 1.8775205612182617
Epoch 10, Average Loss: 1.8193511962890625
Epoch 15, Average Loss: 1.7734841108322144
Epoch 20, Average Loss: 1.7543021440505981
Epoch 25, Average Loss: 1.7264974117279053
Epoch 30, Average Loss: 1.7076013088226318
Epoch 35, Average Loss: 1.6904443502426147
Epoch 40, Average Loss: 1.6748626232147217


In [13]:
model.init_hash_tables(train_features)

In [14]:
corpus_indices = model.get_corpus_indices(test_features)

Creating Corpus for Test Image: 100%|██████████| 10000/10000 [00:04<00:00, 2338.58it/s]


In [15]:
top_matches = get_top_matches(train_features, test_features, corpus_indices)

100%|██████████| 10000/10000 [05:54<00:00, 28.18it/s]


In [16]:
precision_10, precision_50, mean_ap = evaluate(train_labels, test_labels, top_matches)

print(f'Mean Precision@10: {precision_10:.4f}')
print(f'Mean Precision@50: {precision_50:.4f}')
print(f'Mean Average Precision: {mean_ap:.4f}')

Mean Precision@10: 0.4610
Mean Precision@50: 0.3499
Mean Average Precision: 0.4934
