A small notebook that performs a sanity/correctness check on the evaluation function by comparing its metrics against reference implementations from torchmetrics and scikit-learn.

In [1]:
# Preliminary setup.

# Enable hot reload of imported code: load IPython's autoreload extension and
# switch it to mode 2, so imported modules are reloaded before each cell runs
# and edits to the project sources are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
# Enable `src.*` imports if needed by adding the parent directory to the module
# search path (setting an env variable may be enough too).
# NOTE(review): assumes the kernel's working directory is one level below the
# project root — confirm against how the notebook is launched.
import sys
sys.path.append("..")

In [3]:
from sklearn.metrics import ndcg_score, average_precision_score
import torch
from src.evaluation.evaluation_utils import evaluate

In [4]:
import torchmetrics
import numpy as np

In [5]:
# generate test set

# Seed torch's RNG so the random classes/embeddings below (and therefore any
# assertion failure in the checks that follow) are reproducible after a
# Restart Kernel -> Run All.
seed = 0
torch.manual_seed(seed)

n = 1000
class_num = 8
emb_size = 16

no_class_v = torch.arange(n)
# one random integer class id in [0, class_num) per item
class_v = (torch.rand(n)*class_num).floor().int()

# instance-level ground truth: item i is relevant only to item i
true_similarity = torch.eye(n)
# class-level ground truth: entry (i, j) is 1 when items i and j share a class
true_similarity_class = (class_v == class_v.unsqueeze(1)).int()


scenes = torch.rand((n,emb_size))
descs = torch.rand((n,emb_size))

# # perfect match:
# assert emb_size >= n
# scenes = torch.zeros((n,emb_size))
# for i in range(n):
#     scenes[i,i] =1
# descs = scenes

# pairwise cosine similarity: rows index scenes, columns index descriptions
similarity_matrix = torch.nn.functional.cosine_similarity(
        scenes.unsqueeze(1), descs.unsqueeze(0), dim=-1
    )

# metrics computed by the function under test; compared with the reference
# implementations in the cells below
eval_res = evaluate(descs, scenes, class_v, class_v)

In [6]:
def recall_precision_at_k(similarity_mat, true_similarity, k = None):
    """Average torchmetrics recall and precision over all rows, in percent.

    Thin proxy around the torchmetrics functional retrieval metrics: row ``i``
    of ``similarity_mat`` is passed as the predictions and row ``i`` of
    ``true_similarity`` as the target of a single query.

    Args:
        similarity_mat: indexable collection of per-query score rows.
        true_similarity: per-query relevance rows, indexed like
            ``similarity_mat``.
        k: forwarded to torchmetrics as ``top_k`` (``None`` by default).

    Returns:
        Tuple ``(recall, precision)``, each the mean over all rows times 100.
    """
    per_query = [
        (
            torchmetrics.functional.retrieval.retrieval_recall(
                similarity_mat[row], true_similarity[row], top_k=k
            ),
            torchmetrics.functional.retrieval.retrieval_precision(
                similarity_mat[row], true_similarity[row], top_k=k
            ),
        )
        for row in range(len(similarity_mat))
    ]
    recall_values = [recall for recall, _ in per_query]
    precision_values = [precision for _, precision in per_query]
    mean_recall = np.average(recall_values)
    mean_precision = np.average(precision_values)
    return 100*mean_recall, 100*mean_precision

In [7]:
max_error = 0.01 # this is expressed in percentages


def assert_matches(reference, key):
    """Assert that ``eval_res[key]`` agrees with ``reference`` within ``max_error``.

    The comparison is two-sided (absolute difference): a signed difference
    would silently accept any reference value lower than the evaluated one.

    Args:
        reference: value produced by the reference implementation (percent).
        key: name of the metric to look up in ``eval_res``.

    Raises:
        AssertionError: if the two values differ by more than ``max_error``.
    """
    measured = eval_res[key]
    assert abs(reference - measured) <= max_error, f"{reference} vs {measured}"


# (key prefix, scores, instance-level ground truth, class-level ground truth)
# the t2s direction is obtained by transposing both scores and ground truth
directions = [
    ("s2t", similarity_matrix, true_similarity, true_similarity_class),
    ("t2s", similarity_matrix.T, true_similarity.T, true_similarity_class.T),
]
top_ks = (1, 5, 10)

## recall and precision
for prefix, scores, target, target_class in directions:
    # instance-level ground truth: only recall is compared
    for k in top_ks:
        recall, _ = recall_precision_at_k(scores, target, k=k)
        assert_matches(recall, f"{prefix}_R@{k}")

    # class-level ground truth: both recall and precision are compared
    for k in top_ks:
        recall, precision = recall_precision_at_k(scores, target_class, k=k)
        assert_matches(recall, f"{prefix}_class_R@{k}")
        assert_matches(precision, f"{prefix}_class_P@{k}")


## NDCG
for prefix, scores, target, target_class in directions:
    assert_matches(100*ndcg_score(target, scores), f"{prefix}_avg_ndcg")
    assert_matches(100*ndcg_score(target_class, scores), f"{prefix}_avg_ndcg_by_class")

## mAP
# here average samples is needed to get the correct association, default with 'macro' would transpose rows and columns
for prefix, scores, target, target_class in directions:
    assert_matches(
        100*average_precision_score(target, scores, average='samples'),
        f"{prefix}_mAP",
    )
    assert_matches(
        100*average_precision_score(target_class, scores, average='samples'),
        f"{prefix}_mAP_by_class",
    )

print("all test passed - the code author appears to be mentally sane :)")


all test passed - the code author appears to be mentally sane :)
