In [1]:
import torch
import numpy as np

def load_embeddings(file_path):
    """Load a tensor saved with ``torch.save`` and return it as a NumPy array.

    Args:
        file_path: Path to the saved file. The file is expected to contain a
            single ``torch.Tensor`` (the result must support ``.detach()``).

    Returns:
        numpy.ndarray: The tensor's values, held in CPU memory.
    """
    # map_location="cpu" lets files that were saved from a GPU session be
    # loaded on machines without CUDA.
    # NOTE(review): torch.load unpickles the file, so only point this at
    # files from a trusted source.
    tensor = torch.load(file_path, map_location="cpu")
    # Tensor.numpy() refuses tensors that still track gradients, so detach
    # before converting.
    return tensor.detach().cpu().numpy()

In [2]:
def Cosine_distance(text_sample, train_sample):
    """Return the cosine distance (``1 - cosine similarity``) of two vectors.

    Args:
        text_sample: First vector (anything ``np.dot`` / ``np.linalg.norm``
            accept).
        train_sample: Second vector, same length as ``text_sample``.

    Returns:
        A value in ``[0, 2]``: 0 for vectors pointing the same way, 1 for
        orthogonal vectors, 2 for opposite vectors. If either vector has zero
        norm the similarity is undefined and 1.0 is returned.
    """
    norm_text = np.linalg.norm(text_sample)
    norm_train = np.linalg.norm(train_sample)

    # A zero vector has no direction; report it as "no similarity" instead of
    # dividing by zero.
    if norm_text == 0 or norm_train == 0:
        return 1.0

    cosine_similarity = np.dot(text_sample, train_sample) / (norm_text * norm_train)

    # Floating-point rounding can push the ratio marginally outside [-1, 1],
    # which would yield a slightly negative distance (or one above 2).
    cosine_similarity = np.clip(cosine_similarity, -1.0, 1.0)

    return 1 - cosine_similarity

In [3]:
def find_k_nearest_neighbors(text_embedding, train_embeddings, k):
    """Return the indices of the ``k`` training rows closest to a query.

    Closeness is cosine distance (``1 - cosine similarity``). A row (or a
    query) with zero norm is given distance 1.0. Ties are broken by the lower
    row index.

    Args:
        text_embedding: Query vector.
        train_embeddings: Training vectors, one per row — assumed to be a 2-D
            array-like of shape (n_samples, dim); TODO confirm against the
            saved tensors.
        k: Number of neighbours to return.

    Returns:
        list[int]: Up to ``k`` row indices ordered from nearest to farthest.
        Empty when ``k <= 0`` or there are no training rows.
    """
    if k <= 0 or len(train_embeddings) == 0:
        return []

    query = np.asarray(text_embedding)
    train = np.asarray(train_embeddings)

    # One matrix-vector product replaces a Python-level loop over every row.
    dots = train @ query
    norm_products = np.linalg.norm(train, axis=1) * np.linalg.norm(query)

    # Default distance 1.0 covers zero-norm rows / a zero-norm query, where
    # the cosine similarity is undefined.
    distances = np.ones(len(train), dtype=float)
    valid = norm_products > 0
    # Clamp to absorb rounding that would push the ratio outside [-1, 1].
    similarities = np.clip(dots[valid] / norm_products[valid], -1.0, 1.0)
    distances[valid] = 1.0 - similarities

    # A stable sort keeps equal distances in index order.
    nearest = np.argsort(distances, kind="stable")[:k]
    return nearest.tolist()

In [4]:
def mean_reciprocal_rank_1(nearest_neighbors, train_labels):
    """Mean reciprocal rank when query ``i`` is expected to retrieve label ``i``.

    Each query's neighbour list is scanned in order; the first neighbour whose
    training label equals the query's own position contributes ``1 / rank``
    (ranks start at 1). A query with no matching neighbour contributes 0.

    Args:
        nearest_neighbors: One ranked sequence of training indices per query.
        train_labels: Training labels, indexable by those indices.

    Returns:
        float: The mean reciprocal rank, or 0.0 when there are no queries.
    """
    num_queries = len(nearest_neighbors)
    if num_queries == 0:
        # Nothing to average; avoid dividing by zero.
        return 0.0

    reciprocal_rank_sum = 0.0
    for query_idx in range(num_queries):
        # The query's position in the list doubles as its ground-truth label.
        for rank, neighbor_idx in enumerate(nearest_neighbors[query_idx], start=1):
            if train_labels[neighbor_idx] == query_idx:
                reciprocal_rank_sum += 1 / rank
                break

    return reciprocal_rank_sum / num_queries

def precision_at_k_1(nearest_neighbors, train_labels, k=100):
    """Mean precision@k when query ``i`` is expected to retrieve label ``i``.

    For each query, counts how many of its first ``k`` neighbours carry a
    training label equal to the query's own position, divides by ``k`` and
    averages over all queries.

    Args:
        nearest_neighbors: One ranked sequence of training indices per query.
        train_labels: Training labels, indexable by those indices.
        k: Cut-off rank; must be positive. Lists shorter than ``k`` are still
            divided by ``k``.

    Returns:
        float: The mean precision@k, or 0.0 when there are no queries.

    Raises:
        ValueError: If ``k`` is not positive.
    """
    if k <= 0:
        raise ValueError(f"k must be a positive integer, got {k}")

    num_queries = len(nearest_neighbors)
    if num_queries == 0:
        # Nothing to average; avoid dividing by zero.
        return 0.0

    # Count matches as an integer and divide once at the end, so no rounding
    # error builds up across queries.
    relevant_total = 0
    for query_idx in range(num_queries):
        # The query's position in the list doubles as its ground-truth label.
        for neighbor_idx in nearest_neighbors[query_idx][:k]:
            if train_labels[neighbor_idx] == query_idx:
                relevant_total += 1

    return relevant_total / (k * num_queries)

def hit_rate_at_k_1(nearest_neighbors, train_labels, k=100):
    """Hit rate@k when query ``i`` is expected to retrieve label ``i``.

    A query counts as a hit if at least one of its first ``k`` neighbours has
    a training label equal to the query's own position.

    Args:
        nearest_neighbors: One ranked sequence of training indices per query.
        train_labels: Training labels, indexable by those indices.
        k: Cut-off rank; must not be negative.

    Returns:
        float: Fraction of queries with a hit, or 0.0 when there are no
        queries.

    Raises:
        ValueError: If ``k`` is negative (a negative slice bound would
            silently drop neighbours from the end of each list instead).
    """
    if k < 0:
        raise ValueError(f"k must not be negative, got {k}")

    num_queries = len(nearest_neighbors)
    if num_queries == 0:
        # Nothing to average; avoid dividing by zero.
        return 0.0

    hits = 0
    for query_idx in range(num_queries):
        # The query's position in the list doubles as its ground-truth label.
        top_k = nearest_neighbors[query_idx][:k]
        if any(train_labels[neighbor_idx] == query_idx for neighbor_idx in top_k):
            hits += 1

    return hits / num_queries

In [5]:
def mean_reciprocal_rank_2(nearest_neighbors, train_labels, test_labels):
    """Mean reciprocal rank of the first neighbour sharing the query's label.

    Each query's neighbour list is scanned in order; the first neighbour whose
    training label equals ``test_labels[i]`` contributes ``1 / rank`` (ranks
    start at 1). A query with no matching neighbour contributes 0.

    Args:
        nearest_neighbors: One ranked sequence of training indices per query.
        train_labels: Training labels, indexable by those indices.
        test_labels: Ground-truth label of each query, aligned with
            ``nearest_neighbors``.

    Returns:
        float: The mean reciprocal rank, or 0.0 when there are no queries.
    """
    num_queries = len(nearest_neighbors)
    if num_queries == 0:
        # Nothing to average; avoid dividing by zero.
        return 0.0

    reciprocal_rank_sum = 0.0
    for query_idx in range(num_queries):
        wanted_label = test_labels[query_idx]
        for rank, neighbor_idx in enumerate(nearest_neighbors[query_idx], start=1):
            if train_labels[neighbor_idx] == wanted_label:
                reciprocal_rank_sum += 1 / rank
                break

    return reciprocal_rank_sum / num_queries


def precision_at_k_2(nearest_neighbors, train_labels, test_labels, k=100):
    """Mean precision@k of the retrieved neighbours against query labels.

    For each query, counts how many of its first ``k`` neighbours carry a
    training label equal to ``test_labels[i]``, divides by ``k`` and averages
    over all queries.

    Args:
        nearest_neighbors: One ranked sequence of training indices per query.
        train_labels: Training labels, indexable by those indices.
        test_labels: Ground-truth label of each query, aligned with
            ``nearest_neighbors``.
        k: Cut-off rank; must be positive. Lists shorter than ``k`` are still
            divided by ``k``.

    Returns:
        float: The mean precision@k, or 0.0 when there are no queries.

    Raises:
        ValueError: If ``k`` is not positive.
    """
    if k <= 0:
        raise ValueError(f"k must be a positive integer, got {k}")

    num_queries = len(nearest_neighbors)
    if num_queries == 0:
        # Nothing to average; avoid dividing by zero.
        return 0.0

    # Count matches as an integer and divide once at the end, so no rounding
    # error builds up across queries.
    relevant_total = 0
    for query_idx in range(num_queries):
        wanted_label = test_labels[query_idx]
        for neighbor_idx in nearest_neighbors[query_idx][:k]:
            if train_labels[neighbor_idx] == wanted_label:
                relevant_total += 1

    return relevant_total / (k * num_queries)

def hit_rate_at_k_2(nearest_neighbors, train_labels, test_labels, k=100):
    """Fraction of queries with a same-label neighbour in their top ``k``.

    A query counts as a hit if at least one of its first ``k`` neighbours has
    a training label equal to ``test_labels[i]``.

    Args:
        nearest_neighbors: One ranked sequence of training indices per query.
        train_labels: Training labels, indexable by those indices.
        test_labels: Ground-truth label of each query, aligned with
            ``nearest_neighbors``.
        k: Cut-off rank; must not be negative.

    Returns:
        float: Fraction of queries with a hit, or 0.0 when there are no
        queries.

    Raises:
        ValueError: If ``k`` is negative (a negative slice bound would
            silently drop neighbours from the end of each list instead).
    """
    if k < 0:
        raise ValueError(f"k must not be negative, got {k}")

    num_queries = len(nearest_neighbors)
    if num_queries == 0:
        # Nothing to average; avoid dividing by zero.
        return 0.0

    hits = 0
    for query_idx in range(num_queries):
        wanted_label = test_labels[query_idx]
        top_k = nearest_neighbors[query_idx][:k]
        if any(train_labels[neighbor_idx] == wanted_label for neighbor_idx in top_k):
            hits += 1

    return hits / num_queries

In [6]:
# Directory that holds the saved tensors. Defined once so that moving the
# data only requires changing a single line.
DATA_DIR = 'SMAI A1-20250202T180732Z-001/SMAI A1'

# Training set: embeddings and their labels.
train_embeddings = load_embeddings(f'{DATA_DIR}/train_embeddings.pth')
train_labels = load_embeddings(f'{DATA_DIR}/train_labels.pth')

# Text queries. NOTE(review): the *_1 metrics treat text embedding i as
# belonging to label i — confirm that matches how this file was produced.
text_embeddings = load_embeddings(f'{DATA_DIR}/text_embedding.pth')

# Test set: embeddings and their labels.
test_embeddings = load_embeddings(f'{DATA_DIR}/test_embeddings.pth')
test_labels = load_embeddings(f'{DATA_DIR}/test_labels.pth')

In [7]:
# Retrieval with the text embeddings as queries: rank the training embeddings
# for every text embedding and keep the k nearest.
k = 100
nearest_neighbors = []

for text in text_embeddings:
    nearest_embeds = find_k_nearest_neighbors(text, train_embeddings, k)
    nearest_neighbors.append(nearest_embeds)

# Pass the same k to the metrics and the printed labels so they cannot drift
# apart from the neighbour search above.
mrr = mean_reciprocal_rank_1(nearest_neighbors, train_labels)
print(f"MRR: {mrr}")
precision = precision_at_k_1(nearest_neighbors, train_labels, k=k)
hit_rate = hit_rate_at_k_1(nearest_neighbors, train_labels, k=k)

print(f"Precision@{k}: {precision}")
print(f"Hit Rate@{k}: {hit_rate}")

MRR: 1.0
Precision@100: 0.974
Hit Rate@100: 1.0


In [8]:
# Retrieval with the test embeddings as queries: rank the training embeddings
# for every test point and keep the k nearest.
k = 100
nearest_neighbors = []

for test_point in test_embeddings:
    nearest_embeds = find_k_nearest_neighbors(test_point, train_embeddings, k)
    nearest_neighbors.append(nearest_embeds)

# Pass the same k to the metrics and the printed labels so they cannot drift
# apart from the neighbour search above.
mrr = mean_reciprocal_rank_2(nearest_neighbors, train_labels, test_labels)
print(f"MRR: {mrr}")
precision = precision_at_k_2(nearest_neighbors, train_labels, test_labels, k=k)
hit_rate = hit_rate_at_k_2(nearest_neighbors, train_labels, test_labels, k=k)

print(f"Precision@{k}: {precision}")
print(f"Hit Rate@{k}: {hit_rate}")

MRR: 0.9347961513315047
Precision@100: 0.8410819999999664
Hit Rate@100: 0.9996


In [8]:
def average_comparisons(test_embeddings, train_embeddings):
    """Average number of distance computations per query for exhaustive search.

    The brute-force search compares a query against every training sample, so
    the per-query average is simply the number of training samples, however
    many queries there are.

    Args:
        test_embeddings: The query set; only its length is used.
        train_embeddings: The training set; only its length is used.

    Returns:
        float: ``len(train_embeddings)`` as a float, or 0.0 when there are no
        queries.
    """
    if len(test_embeddings) == 0:
        # No queries means no comparisons; avoid dividing by zero.
        return 0.0

    # Summing len(train_embeddings) once per query and dividing by the number
    # of queries always gives len(train_embeddings) back.
    return float(len(train_embeddings))

# Report how many distance computations a single query costs on average.
avg_comparisons = average_comparisons(
    test_embeddings=test_embeddings,
    train_embeddings=train_embeddings,
)
print(f"Average number of comparisons per query: {avg_comparisons}")


Average number of comparisons per query: 50000.0
