In [15]:
import metrics

def calculate_avg_precision(gt, pred):
    """
    Score a ranked prediction list against a ground-truth list.

    Walks ``pred`` in rank order. Each time an item that belongs to ``gt``
    is met, the precision at that rank (hits so far / 1-based rank) is added
    to a running total. The total is finally divided by ``len(gt)``.

    Neither list is de-duplicated: a relevant item repeated in ``pred`` is
    counted as a hit every time it appears, and repeated entries in ``gt``
    enlarge the divisor.

    Args:
    gt (list): Ground truth list of relevant items (documents).
    pred (list): Predicted list of items (documents), sorted by relevance score.

    Returns:
    float: Sum of precision-at-hit values divided by ``len(gt)``;
        0.0 when ``gt`` is empty.
    """
    if not gt:
        return 0.0

    relevant = set(gt)  # O(1) membership checks
    hits = 0
    precision_sum = 0.0

    for rank, item in enumerate(pred, start=1):
        if item not in relevant:
            continue
        hits += 1
        precision_sum += hits / rank  # precision at this rank

    # Divide by the raw ground-truth length (not by the number of hits).
    return precision_sum / len(gt)

def calculate_avg_recall(gt, pred):
    """
    Average the running precision over the hits actually found in ``pred``.

    Walks ``pred`` in rank order. Each time an item that belongs to ``gt``
    is met, the precision at that rank (hits so far / 1-based rank) is added
    to a running total. The total is divided by the number of hits found in
    ``pred`` -- not by ``len(gt)`` as ``calculate_avg_precision`` does.

    Args:
    gt (list): Ground truth list of relevant items (documents).
    pred (list): Predicted list of items (documents), sorted by relevance score.

    Returns:
    float: Sum of precision-at-hit values divided by the number of hits;
        0.0 when ``gt`` is empty or when no item of ``pred`` is in ``gt``.
    """
    if not gt:
        return 0.0

    relevant_items = set(gt)  # O(1) membership checks
    score = 0.0
    num_hits = 0

    for rank, item in enumerate(pred, start=1):
        if item in relevant_items:
            num_hits += 1
            score += num_hits / rank  # precision at this rank

    # Without this guard an all-miss (or empty) ``pred`` divides by zero.
    if num_hits == 0:
        return 0.0

    # Normalize by the number of relevant items that were retrieved.
    return score / num_hits



In [31]:
# Toy ranking: of the six predictions only 1 (rank 1) and 5 (rank 3) occur in `gt`.
gt = [1, 2, 3, 4, 5, 6]
pred = [1, 30, 5, 900, 2000, 300]

# One value per line, in this order: the local precision helper, metrics._apk,
# metrics._ark, then the local hit-normalised helper.
print(
    calculate_avg_precision(gt, pred),
    metrics._apk(gt, pred),
    metrics._ark(gt, pred),
    calculate_avg_recall(gt, pred),
    sep="\n",
)

0.27777777777777773
0.8333333333333333
0.27777777777777773
0.8333333333333333


In [2]:
import numpy as np
from heapq import heappush, heappop, nlargest
from typing import List, Dict, Tuple
import random

def original_method(candidates: List[str], all_distances: Dict[str, List[List[Tuple[float, int]]]], pk: int) -> Tuple[List[str], List[float]]:
    """
    Rank candidates by summed first-entry distance using a full sort.

    For every candidate, the distance stored in the first ``(distance, _)``
    tuple of each inner list in ``all_distances[candidate]`` is added up.
    Candidates are then ordered by ascending total with ``np.argsort`` and
    the leading ``pk`` entries are returned.

    Args:
    candidates (List[str]): Keys of ``all_distances`` to rank.
    all_distances (Dict): Maps a candidate to a list of non-empty lists of
        ``(distance, index)`` tuples; only element 0 of each inner list is read.
    pk (int): Slice bound applied to the sorted results (``[:pk]``).

    Returns:
    Tuple[List[str], List[float]]: The selected candidates and their totals,
        both in ascending-total order.
    """
    names = list(candidates)

    totals = []
    for name in names:
        # Accumulate step by step (not via sum()) to keep float results unchanged.
        total = 0
        for per_query in all_distances[name]:
            first_dist, _ = per_query[0]
            total += first_dist
        totals.append(total)

    # Positions of the totals in ascending order, truncated to pk.
    top_positions = np.argsort(totals)[:pk]
    top_names = [names[pos] for pos in top_positions]
    top_totals = [totals[pos] for pos in top_positions]
    return top_names, top_totals

def heap_method(candidates: List[str], all_distances: Dict[str, List[List[Tuple[float, int]]]], pk: int) -> Tuple[List[str], List[float]]:
    """
    Rank candidates by summed first-entry distance using a size-bounded heap.

    For every candidate, the distance stored in the first ``(distance, _)``
    tuple of each inner list in ``all_distances[candidate]`` is added up.
    The negated total is pushed onto a min-heap that is trimmed back to at
    most ``pk`` entries after each push, so the entry with the largest total
    is the one discarded. The survivors are returned in ascending-total order.

    Args:
    candidates (List[str]): Keys of ``all_distances`` to rank.
    all_distances (Dict): Maps a candidate to a list of non-empty lists of
        ``(distance, index)`` tuples; only element 0 of each inner list is read.
    pk (int): Maximum number of candidates to keep.

    Returns:
    Tuple[List[str], List[float]]: The kept candidates and their totals,
        both in ascending-total order.
    """
    bounded_heap = []
    for name in candidates:
        # Accumulate step by step (not via sum()) to keep float results unchanged.
        total = 0
        for per_query in all_distances[name]:
            first_dist, _ = per_query[0]
            total += first_dist

        # Negating the key makes the heap root the largest total so far.
        heappush(bounded_heap, (-total, name))
        if len(bounded_heap) > pk:
            heappop(bounded_heap)

    # Largest negated totals first == smallest real totals first.
    ranked = nlargest(pk, bounded_heap)

    kept_names = []
    kept_totals = []
    for negated_total, name in ranked:
        kept_names.append(name)
        kept_totals.append(-negated_total)
    return kept_names, kept_totals

def generate_random_data(num_candidates: int, num_queries: int) -> Tuple[List[str], Dict[str, List[List[Tuple[float, int]]]]]:
    """
    Build random distance data for comparing the two ranking methods.

    Candidates are named ``img0`` .. ``img{num_candidates - 1}``. Each
    candidate gets ``num_queries`` lists; every list holds between 1 and 5
    tuples ``(distance, 0)`` with ``distance`` drawn by
    ``random.uniform(0, 10)``. Draws come from the module-level ``random``
    generator.

    Args:
    num_candidates (int): Number of candidate names to create.
    num_queries (int): Number of distance lists per candidate.

    Returns:
    Tuple: ``(candidates, all_distances)`` where ``all_distances`` maps each
        candidate name to its list of distance lists.
    """
    candidates = []
    all_distances = {}

    for idx in range(num_candidates):
        name = f'img{idx}'
        candidates.append(name)

        per_query_lists = []
        for _ in range(num_queries):
            # Draw the list length first, then that many distances.
            how_many = random.randint(1, 5)
            per_query_lists.append([(random.uniform(0, 10), 0) for _ in range(how_many)])
        all_distances[name] = per_query_lists

    return candidates, all_distances

def run_test(num_candidates: int, num_queries: int, pk: int) -> bool:
    """
    Check that ``heap_method`` agrees with ``original_method`` on random data.

    Generates one random data set, runs both ranking methods on it, prints
    both result tuples, and compares them.

    Args:
    num_candidates (int): Number of candidates to generate.
    num_queries (int): Number of distance lists per candidate.
    pk (int): Number of top candidates requested from each method.

    Returns:
    bool: True when both methods return the same candidates in the same
        order and their totals agree under ``np.allclose``
        (``rtol=1e-5``, ``atol=1e-8``).
    """
    candidates, all_distances = generate_random_data(num_candidates, num_queries)

    reference = original_method(candidates, all_distances, pk)
    contender = heap_method(candidates, all_distances, pk)
    print(reference)
    print(contender)

    reference_imgs, reference_totals = reference
    contender_imgs, contender_totals = contender

    # Both checks are always evaluated before being combined.
    same_ranking = reference_imgs == contender_imgs
    same_totals = np.allclose(reference_totals, contender_totals, rtol=1e-5, atol=1e-8)
    return same_ranking and same_totals

# Hand-picked scenarios. Each tuple is (num_candidates, num_queries, pk) and is
# passed straight to run_test, which generates fresh random data per call and
# prints both methods' results before returning whether they agree.
test_cases = [
    (10, 5, 3),    # Small case
    (100, 10, 20), # Medium case
    (1000, 20, 50), # Large case
    (5, 3, 10),    # pk > num_candidates
    (1, 1, 1),     # Single candidate, single query
    (0, 0, 0),     # Edge case: no candidates, no queries, pk=0
]

for i, (num_candidates, num_queries, pk) in enumerate(test_cases):
    result = run_test(num_candidates, num_queries, pk)
    # One summary line per scenario: check mark on agreement, cross otherwise.
    print(f"Test case {i+1} ({'✓' if result else '✗'}): {num_candidates} candidates, {num_queries} queries, pk={pk}")

# Randomised sweep: 1-1000 candidates, 1-20 queries, pk in 1-50 per run.
# NOTE(review): no random.seed() call is visible in this notebook excerpt, so
# the generated cases presumably differ between runs -- confirm, and seed if
# reproducible output is wanted.
num_random_tests = 100
random_test_results = [run_test(random.randint(1, 1000), random.randint(1, 20), random.randint(1, 50)) for _ in range(num_random_tests)]
# run_test returns a truthy value on agreement, so the sum is the pass count.
print(f"\nRandom tests: {sum(random_test_results)} / {num_random_tests} passed")

(['img0', 'img9', 'img5'], [14.733169328322228, 17.98338533096411, 19.412110644679924])
(['img0', 'img9', 'img5'], [14.733169328322228, 17.98338533096411, 19.412110644679924])
Test case 1 (✓): 10 candidates, 5 queries, pk=3
(['img4', 'img5', 'img30', 'img61', 'img36', 'img87', 'img71', 'img21', 'img43', 'img24', 'img26', 'img8', 'img28', 'img86', 'img16', 'img42', 'img78', 'img58', 'img60', 'img12'], [30.467763150112457, 30.51904092971643, 31.882131164920924, 32.294256674950546, 33.27614350329782, 33.819387840972645, 34.845477138645684, 35.53393152244784, 36.26124707613847, 36.45475552965617, 36.840188080678814, 37.079767361754236, 37.1531280556785, 37.24544434201456, 38.35487791587841, 39.23607338171355, 39.44706965305382, 39.49259400307493, 39.63510815925888, 39.79907878396241])
(['img4', 'img5', 'img30', 'img61', 'img36', 'img87', 'img71', 'img21', 'img43', 'img24', 'img26', 'img8', 'img28', 'img86', 'img16', 'img42', 'img78', 'img58', 'img60', 'img12'], [30.467763150112457, 30.5190