# Прогон тестов для метрики TPR@FPR

Здесь представлена черновая реализация метрики TPR@FPR для проверки корректности тестов.

.∧＿∧ 
( ･ω･｡)つ━☆・*。 
⊂　 ノ 　　　・゜+. 
しーＪ　　　°。+ *´¨) 
　　　　　　　　　.· ´¸.·*´¨) ¸.·*¨) 
　　　　　　　　　　(¸.·´ (¸.·'* ☆Итоговую версию метрики можно найти, перейдя по импорту ниже.

In [ ]:
from metrics.tprfpr import compute_ir_metric

In [1]:
from torch.nn  import CosineSimilarity
import torch
import numpy as np

In [2]:
# If you are working with the data provided via the link,
# this cell will help you load it.
import os
from collections import defaultdict

# File with query-part annotations: which image belongs to which class.
# The code below drops the first line (presumably a header row — confirm
# against the actual file) and splits every remaining line on a comma:
#     image_name_1.jpg,2678
#     image_name_2.jpg,2679
with open('./celebA_ir/celebA_anno_query.csv', 'r') as f:
    # The context manager closes the file even if reading raises.
    query_lines = f.readlines()[1:]
# Blank lines (e.g. a trailing empty line) are skipped so that the two-value
# unpacking in the loop below does not fail on them.
query_lines = [x.strip().split(',') for x in query_lines if x.strip()]
# Plain list of image names from query. Needed to compute embeddings for query.
query_img_names = [x[0] for x in query_lines]

# Dictionary with info of which images from query belong to which class.
# format:
#     {class: [image_1, image_2, ...]}
query_dict = defaultdict(list)
for img_name, img_class in query_lines:
    query_dict[img_class].append(img_name)

# List of distractor images.
distractors_img_names = os.listdir('./celebA_ir/celebA_distractors')

In [5]:
def compute_cosine_query_pos(query_dict, query_img_names, query_embeddings):
    '''
    compute cosine similarities between positive pairs from query (stage 1)
    params:
      query_dict: dict {class: [image_name_1, image_name_2, ...]}. Key: class in
                  the dataset. Value: images corresponding to that class
      query_img_names: list of images names
      query_embeddings: list of embeddings corresponding to query_img_names
    output:
      list of floats: similarities between embeddings corresponding
                      to the same people from query list. Classes are visited
                      in query_dict iteration order; inside a class every
                      unordered pair of images is emitted exactly once.
    raises:
      ValueError: if an image that takes part in a pair is missing from
                  query_img_names
    '''
    from itertools import combinations

    cosine_similarity = CosineSimilarity(dim=0)
    tensor = torch.FloatTensor(query_embeddings)

    # Build the name -> row lookup once instead of scanning the name list for
    # every pair. setdefault keeps the first occurrence, like list.index().
    first_index = {}
    for position, name in enumerate(query_img_names):
        first_index.setdefault(name, position)

    def index_of(name):
        # Keep list.index() semantics for unknown names (ValueError).
        try:
            return first_index[name]
        except KeyError:
            raise ValueError(f'{name!r} is not in query_img_names') from None

    pos_similarities = []
    for imgs in query_dict.values():
        # combinations() yields the same (i < j) pairs, in the same order,
        # as a nested positional double loop would.
        for img_i, img_j in combinations(imgs, 2):
            tensor_i = tensor[index_of(img_i)]
            tensor_j = tensor[index_of(img_j)]
            pos_similarities.append(cosine_similarity(tensor_i, tensor_j).item())
    return pos_similarities
                    

def compute_cosine_query_neg(query_dict, query_img_names, query_embeddings):
    '''
    compute cosine similarities between negative pairs from query (stage 2)
    params:
      query_dict: dict {class: [image_name_1, image_name_2, ...]}. Key: class in
                  the dataset. Value: images corresponding to that class.
                  Keys must be mutually comparable with "<".
      query_img_names: list of images names
      query_embeddings: list of embeddings corresponding to query_img_names
    output:
      list of floats: similarities between embeddings corresponding
                      to different people from query list. Every unordered
                      pair of classes contributes each cross-class image pair
                      exactly once.
    raises:
      ValueError: if an image that takes part in a pair is missing from
                  query_img_names
    '''
    cosine_similarity = CosineSimilarity(dim=0)
    tensor = torch.FloatTensor(query_embeddings)

    # Build the name -> row lookup once instead of scanning the name list for
    # every pair. setdefault keeps the first occurrence, like list.index().
    first_index = {}
    for position, name in enumerate(query_img_names):
        first_index.setdefault(name, position)

    def index_of(name):
        # Keep list.index() semantics for unknown names (ValueError).
        try:
            return first_index[name]
        except KeyError:
            raise ValueError(f'{name!r} is not in query_img_names') from None

    neg_similarities = []
    for smaller_key in query_dict:
        for larger_key in query_dict:
            # The strict "<" keeps each unordered class pair exactly once and
            # drops the pairing of a class with itself.
            if not smaller_key < larger_key:
                continue
            # Images of the larger key drive the outer loop; this ordering of
            # the output list is kept deliberately.
            for outer_img in query_dict[larger_key]:
                for inner_img in query_dict[smaller_key]:
                    outer_tensor = tensor[index_of(outer_img)]
                    inner_tensor = tensor[index_of(inner_img)]
                    similarity = cosine_similarity(outer_tensor, inner_tensor).item()
                    neg_similarities.append(similarity)
    return neg_similarities
                        
    

def compute_cosine_query_distractors(query_embeddings, distractors_embeddings):
    '''
    compute cosine similarities between negative pairs from query and distractors
    (stage 3)
    params:
      query_embeddings: list of embeddings corresponding to query_img_names
      distractors_embeddings: list of embeddings corresponding to distractors_img_names
    output:
      list of floats: similarities between pairs of people (q, d), where q is
                      embedding corresponding to photo from query, d —
                      embedding corresponding to photo from distractors.
                      Ordered query-major: all distractors for the first
                      query embedding, then all for the second, and so on.
    '''
    cosine_similarity = CosineSimilarity(dim=0)

    # Convert every distractor once up front rather than once per
    # (query, distractor) pair.
    distractor_tensors = [torch.FloatTensor(emb) for emb in distractors_embeddings]

    similarities = []
    for query_emb in query_embeddings:
        # Likewise, each query embedding is converted once per outer iteration.
        query_tensor = torch.FloatTensor(query_emb)
        for distractor_tensor in distractor_tensors:
            similarities.append(
                cosine_similarity(query_tensor, distractor_tensor).item()
            )
    return similarities

In [6]:
# Toy query set: class id -> names of the images that belong to that class.
test_query_dict = {
    2876: ['1.jpg', '2.jpg', '3.jpg'],
    5674: ['5.jpg'],
    864: ['9.jpg', '10.jpg'],
}
test_query_img_names = [
    '1.jpg', '2.jpg', '3.jpg',
    '5.jpg',
    '9.jpg', '10.jpg',
]
# One 3-component embedding per entry of test_query_img_names, same order.
test_query_embeddings = [
    [1.56, 6.45, -7.68],
    [-1.1, 6.11, -3.0],
    [-0.06, -0.98, -1.29],
    [8.56, 1.45, 1.11],
    [0.7, 1.1, -7.56],
    [0.05, 0.9, -2.56],
]

test_distractors_img_names = [
    '11.jpg', '12.jpg', '13.jpg', '14.jpg', '15.jpg',
]

# One 3-component embedding per entry of test_distractors_img_names.
test_distractors_embeddings = [
    [0.12, -3.23, -5.55],
    [-1, -0.01, 1.22],
    [0.06, -0.23, 1.34],
    [-6.6, 1.45, -1.45],
    [0.89, 1.98, 1.45],
]

# Run the three similarity stages on the toy data.
test_cosine_query_pos = compute_cosine_query_pos(
    query_dict=test_query_dict,
    query_img_names=test_query_img_names,
    query_embeddings=test_query_embeddings,
)
test_cosine_query_neg = compute_cosine_query_neg(
    query_dict=test_query_dict,
    query_img_names=test_query_img_names,
    query_embeddings=test_query_embeddings,
)
test_cosine_query_distractors = compute_cosine_query_distractors(
    query_embeddings=test_query_embeddings,
    distractors_embeddings=test_distractors_embeddings,
)

In [7]:
# Reference similarities for the positive (same-class) query pairs.
true_cosine_query_pos = [
    0.8678237233650096, 0.21226104378511604,
    -0.18355866977496182, 0.9787437979250561,
]
# Both sides are sorted, so the check does not depend on pair order.
assert np.allclose(
    sorted(test_cosine_query_pos), sorted(true_cosine_query_pos)
), "A mistake in compute_cosine_query_pos function"

# Print the element-wise gap between the reference and the computed values.
for expected, actual in zip(true_cosine_query_pos, test_cosine_query_pos):
    print(expected - actual)


# Reference similarities for the negative (different-class) query pairs.
true_cosine_query_neg = [
    0.15963231223161822, 0.8507997093616965, 0.9272761484302097,
    -0.0643994061127092, 0.5412660901220571, 0.701307100338029,
    -0.2372575528216902, 0.6941032794522218, 0.549425446066643,
    -0.011982733001947084, -0.0466679194884999,
]
assert np.allclose(
    sorted(test_cosine_query_neg), sorted(true_cosine_query_neg)
), "A mistake in compute_cosine_query_neg function"


# Reference similarities for every (query, distractor) pair.
true_cosine_query_distractors = [
    0.3371426578637511, -0.6866465610863652, -0.8456563512871669,
    0.14530087113136106, 0.11410510307646118, -0.07265097629002357,
    -0.24097699660707042, -0.5851992679925766, 0.4295494455718534,
    0.37604478596058194, 0.9909483738948858, -0.5881093317868022,
    -0.6829712976642919, 0.07546364489032083, -0.9130970963915521,
    -0.17463101988684684, -0.5229363015558941, 0.1399896725311533,
    -0.9258034013399499, 0.5295114163723346, 0.7811585442749943,
    -0.8208760031249596, -0.9905139680301821, 0.14969764653247228,
    -0.40749654525418444, 0.648660814944824, -0.7432584300096284,
    -0.9839696492435877, 0.2498741082804709, -0.2661183373780491,
]
assert np.allclose(
    sorted(test_cosine_query_distractors), sorted(true_cosine_query_distractors)
), "A mistake in compute_cosine_query_distractors function"

-1.15822612500871e-07
5.245020087696339e-08
9.229855502113082e-08
-5.325978891246308e-08


In [8]:
def compute_ir(cosine_query_pos, cosine_query_neg, cosine_query_distractors, fpr=0.1):
    '''
    compute identification rate using precomputed cosine similarities between pairs
    at given fpr
    params:
      cosine_query_pos: cosine similarities between positive pairs from query
      cosine_query_neg: cosine similarities between negative pairs from query
      cosine_query_distractors: cosine similarities between negative pairs
                                from query and distractors
      fpr: false positive rate at which to compute TPR, must lie in [0, 1]
    output:
      float: threshold for given fpr
      float: TPR at given FPR
    raises:
      ValueError: if fpr is outside [0, 1], or if there are no positive pairs
                  or no false pairs to work with
    side effects:
      prints the chosen threshold to stdout
    '''
    # A negative fpr would otherwise index the sorted list from its end and
    # silently yield a meaningless threshold; fpr > 1 would run past the list.
    if not 0 <= fpr <= 1:
        raise ValueError(f"fpr must be within [0, 1], got {fpr!r}")

    # list() makes "+" a concatenation for any sequence type (for numpy
    # arrays "+" would be element-wise addition instead).
    positives = list(cosine_query_pos)
    false_pairs = sorted(
        list(cosine_query_neg) + list(cosine_query_distractors), reverse=True
    )
    if not positives:
        raise ValueError("cosine_query_pos is empty: TPR is undefined")
    if not false_pairs:
        raise ValueError("no false pairs given: the threshold is undefined")

    # Acceptable amount of false pairs. Clamped to the last valid index so
    # that fpr == 1.0 picks the smallest false-pair similarity instead of
    # indexing past the end of the list.
    N = min(int(fpr * len(false_pairs)), len(false_pairs) - 1)

    threshold_similarity = false_pairs[N]
    print(threshold_similarity)

    # A positive pair counts as identified only if it is strictly above the
    # threshold.
    true_positives = sum(1 for s in positives if s > threshold_similarity)

    return threshold_similarity, true_positives / len(positives)


In [9]:
# Collect the threshold and the TPR that compute_ir reports for each FPR level.
test_thr, test_tpr = [], []
for fpr in (0.5, 0.3, 0.1):
    thr_at_fpr, tpr_at_fpr = compute_ir(
        test_cosine_query_pos,
        test_cosine_query_neg,
        test_cosine_query_distractors,
        fpr=fpr,
    )
    test_thr.append(thr_at_fpr)
    test_tpr.append(tpr_at_fpr)

-0.011982724070549011
0.337142676115036
0.7013071179389954


In [10]:
# Display the collected thresholds and TPR values as the cell result.
test_thr, test_tpr

([-0.011982724070549011, 0.337142676115036, 0.7013071179389954],
 [0.75, 0.5, 0.5])

In [11]:
# Reference thresholds, compared element-wise against test_thr.
true_thr = [
    -0.011982733001947084,
    0.3371426578637511,
    0.701307100338029,
]
assert np.allclose(test_thr, true_thr), "A mistake in computing threshold"

# Reference TPR values, compared element-wise against test_tpr.
true_tpr = [0.75, 0.5, 0.5]
assert np.allclose(test_tpr, true_tpr), "A mistake in computing tpr"

In [12]:
# Repeat the sweep over a different set of FPR levels, refilling both lists.
test_thr, test_tpr = [], []
for fpr in (0.5, 0.2, 0.1, 0.05):
    thr_at_fpr, tpr_at_fpr = compute_ir(
        test_cosine_query_pos,
        test_cosine_query_neg,
        test_cosine_query_distractors,
        fpr=fpr,
    )
    test_thr.append(thr_at_fpr)
    test_tpr.append(tpr_at_fpr)

-0.011982724070549011
0.5412660837173462
0.7013071179389954
0.8507997989654541
