In [16]:
import os

import torch
from collections import defaultdict

f = open('./celebA_ir/celebA_anno_query.csv', 'r')
query_lines = f.readlines()[1:]
f.close()
query_lines = [x.strip().split(',') for x in query_lines]
query_img_names = [x[0] for x in query_lines]

query_dict = defaultdict(list)

for img_name, img_class in query_lines:
    query_dict[img_class].append(f"./celebA_ir/celebA_query/{img_name}")

distractors_img_names = os.listdir('./celebA_ir/celebA_distractors')

In [120]:
distractors_img_path = [f"./celebA_ir/celebA_distractors/{x}" for x in distractors_img_names]
distractors_img_path

['./celebA_ir/celebA_distractors/141640.jpg',
 './celebA_ir/celebA_distractors/118735.jpg',
 './celebA_ir/celebA_distractors/116718.jpg',
 './celebA_ir/celebA_distractors/047543.jpg',
 './celebA_ir/celebA_distractors/093074.jpg',
 './celebA_ir/celebA_distractors/001516.jpg',
 './celebA_ir/celebA_distractors/083267.jpg',
 './celebA_ir/celebA_distractors/144738.jpg',
 './celebA_ir/celebA_distractors/006279.jpg',
 './celebA_ir/celebA_distractors/082179.jpg',
 './celebA_ir/celebA_distractors/085170.jpg',
 './celebA_ir/celebA_distractors/041654.jpg',
 './celebA_ir/celebA_distractors/105764.jpg',
 './celebA_ir/celebA_distractors/048116.jpg',
 './celebA_ir/celebA_distractors/015205.jpg',
 './celebA_ir/celebA_distractors/002779.jpg',
 './celebA_ir/celebA_distractors/146649.jpg',
 './celebA_ir/celebA_distractors/099559.jpg',
 './celebA_ir/celebA_distractors/151335.jpg',
 './celebA_ir/celebA_distractors/081316.jpg',
 './celebA_ir/celebA_distractors/032862.jpg',
 './celebA_ir/celebA_distractors/0

In [17]:
query_dict['35']

['./celebA_ir/celebA_query/001265.jpg',
 './celebA_ir/celebA_query/001430.jpg',
 './celebA_ir/celebA_query/012834.jpg',
 './celebA_ir/celebA_query/041171.jpg',
 './celebA_ir/celebA_query/041823.jpg',
 './celebA_ir/celebA_query/052547.jpg',
 './celebA_ir/celebA_query/071369.jpg',
 './celebA_ir/celebA_query/087722.jpg',
 './celebA_ir/celebA_query/101493.jpg',
 './celebA_ir/celebA_query/113930.jpg',
 './celebA_ir/celebA_query/133927.jpg',
 './celebA_ir/celebA_query/136309.jpg',
 './celebA_ir/celebA_query/140075.jpg',
 './celebA_ir/celebA_query/153974.jpg',
 './celebA_ir/celebA_query/154402.jpg',
 './celebA_ir/celebA_query/161094.jpg']

In [90]:
import torch
import torch.nn.functional as F

class SimpleNN(torch.nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.conv1 = torch.nn.Conv2d(
            in_channels=3, 
            out_channels=64, 
            padding=1, 
            kernel_size=3, 
            stride=1, 
            bias=False)
        self.flatten = torch.nn.Flatten()
        self.fc1 = torch.nn.Linear(in_features=531200, out_features=3, bias=False)
        
    def forward(self,  x):
        
        x = self.conv1(x)
        x = F.relu(x)
        x = self.flatten(x)
        x = self.fc1(x)
        
        return x

In [91]:
model = SimpleNN()



In [118]:
from typing import List, Dict
import PIL.Image as Image
from torchvision.transforms import transforms

tf = transforms.Compose([transforms.ToTensor()])
cosine_similarity = torch.nn.CosineSimilarity(dim=0)


def image_list_to_tensor(image_path_list: List[str]) -> torch.Tensor: 
    """
    Transforms  image list to tensor of NxCxHxW
    :param image_path_list: 
    :return: 
    """
    tensor_list = []
    for image_path in image_path_list:
        image = Image.open(image_path)
        tensor_list.append(tf(image).unsqueeze(0))
    return torch.cat(tensor_list, dim=0).detach()


tensor_dict = {}

for key in query_dict:
    file_names = query_dict[key]
    tensor_dict[key] = image_list_to_tensor(file_names)
    
    
embeddings_dict = {}

for key in tensor_dict:
    image_tensor = tensor_dict[key]
    embeddings_dict[key] = model(image_tensor).detach()
    
    
def get_embeddings(model: torch.nn.Module, images_dict: Dict[str, List[str]]) -> Dict[str, torch.Tensor]:   
    """
    Convert image paths dictionary to dictionary of embeddings tensors
    :param model: 
    :param images_dict: 
    :return: 
    """
    embeddings = {}
    for image in images_dict:
        image_tensor = images_dict[image]
        embeddings[key] = model(image_tensor).detach()
    return embeddings   


# same faces similarities
similarities = []
for person_id_key in embeddings_dict:
    person_embeddings = embeddings_dict[person_id_key]
    person_photos_len = person_embeddings.shape[0]
    for i in range(person_photos_len):
         for j in range(i, person_photos_len):
             if i != j:
                tensor_i = person_embeddings[i].detach()
                tensor_j = person_embeddings[j].detach()
                cos = cosine_similarity(tensor_i, tensor_j)
                similarities.append(cos.item())
            
                             
                             
# cross similarities                
different_persons_similarities = []
keys = list(embeddings_dict.keys())
person_ids_len = len(keys)
for i in range(person_ids_len):
    for j in range(i, person_ids_len):
        if i != j:
            i_key = keys[i]
            j_key = keys[j]
            i_embeddings = embeddings_dict[i_key]
            j_embeddings = embeddings_dict[j_key]
            i_embeddings_len = i_embeddings.shape[0]
            j_embeddings_len = j_embeddings.shape[0]
            for k in range(i_embeddings_len):
                for l in range(j_embeddings_len):
                    tensor_k = i_embeddings[k].detach()
                    tensor_l = j_embeddings[l].detach()
                    cos = cosine_similarity(tensor_k, tensor_l)
                    different_persons_similarities.append(cos.item())

[0.941612958908081,
 0.9967120885848999,
 0.8973538875579834,
 0.9463392496109009,
 0.9607146382331848,
 0.9833064079284668,
 0.9230925440788269,
 0.9961410164833069,
 0.9308042526245117,
 0.4979683458805084,
 0.9180880784988403,
 0.8231239914894104,
 0.9846174716949463,
 0.9797742366790771,
 0.9162795543670654,
 0.9963861107826233,
 0.7861795425415039,
 0.9611412882804871,
 0.9798731803894043,
 0.989205002784729,
 0.993726372718811,
 0.9109559059143066,
 0.9502905011177063,
 0.7794703245162964,
 0.9252890348434448,
 0.9403864741325378,
 0.7853130102157593,
 0.8265407085418701,
 0.8925279974937439,
 0.7413716912269592,
 0.9060980081558228,
 0.9689602851867676,
 0.6350346803665161,
 0.9992460012435913,
 0.9526197910308838,
 0.9316401481628418,
 0.8650388121604919,
 0.9993219375610352,
 0.8954790830612183,
 0.9533030986785889,
 0.8250319361686707,
 0.9843555688858032,
 0.8873606324195862,
 0.9650113582611084,
 0.7871677875518799,
 0.815348207950592,
 0.9528907537460327,
 0.99455344676971

In [122]:
distractors_tensors = []
for image in distractors_img_path:
    image = Image.open(image)
    tensor = tf(image)
    distractors_tensors.append(tensor.unsqueeze(0))
distractors_batch = torch.cat(distractors_tensors, dim=0)    

torch.Size([2001, 3, 100, 83])

In [124]:
distractor_embeddings = model(distractors_batch).detach()

In [125]:
distractor_embeddings.shape

torch.Size([2001, 3])

In [126]:
query_distractor_cosine_similarity = []
for key in embeddings_dict:
    query_embeddings = embeddings_dict[key]
    query_embeddings_len = query_embeddings.shape[0]
    for i in range(query_embeddings_len):
        distractor_embeddings_len = distractor_embeddings.shape[0]
        for j in range(distractor_embeddings_len):
            query_embedding = query_embeddings[i]
            distractor_embedding = distractor_embeddings[j]
            cos = cosine_similarity(query_embedding, distractor_embedding)
            query_distractor_cosine_similarity.append(cos.item())

In [129]:
query_distractor_cosine_similarity

[0.9719704389572144,
 0.8875658512115479,
 0.9157472848892212,
 0.9198703765869141,
 0.9419053792953491,
 0.9815024137496948,
 0.9653341770172119,
 0.7904399633407593,
 0.9912769198417664,
 0.9645033478736877,
 0.9871326684951782,
 0.9574835300445557,
 0.9370778799057007,
 0.9524180889129639,
 0.9950231313705444,
 0.9635204076766968,
 0.9971747398376465,
 0.987957775592804,
 0.9794161915779114,
 0.968396008014679,
 0.9118874073028564,
 0.9417448043823242,
 0.9397492408752441,
 0.8246406316757202,
 0.8103398084640503,
 0.9555126428604126,
 0.9669681787490845,
 0.9396635293960571,
 0.9708079099655151,
 0.9789514541625977,
 0.9821045994758606,
 0.9244005680084229,
 0.9377046227455139,
 0.8511788249015808,
 0.9981542825698853,
 0.8824886083602905,
 0.9794238209724426,
 0.9290727376937866,
 0.9648818969726562,
 0.9912770986557007,
 0.9496690034866333,
 0.8680692911148071,
 0.9892759323120117,
 0.9840193390846252,
 0.9968096017837524,
 0.992196261882782,
 0.968858003616333,
 0.99183946847915

In [130]:
false_pairs = different_persons_similarities + query_distractor_cosine_similarity

In [132]:
len(false_pairs)

3176532

In [133]:
false_positive_rate = 0.01

In [137]:
N = int(false_positive_rate * len(false_pairs))

In [138]:
N

31765

In [139]:
false_pairs = sorted(false_pairs)

In [140]:
false_pairs

[-0.55887371301651,
 -0.49167290329933167,
 -0.48837465047836304,
 -0.4849374294281006,
 -0.4491637349128723,
 -0.4432150423526764,
 -0.44032979011535645,
 -0.42208582162857056,
 -0.4173530042171478,
 -0.4117407500743866,
 -0.40997666120529175,
 -0.4030684530735016,
 -0.39724159240722656,
 -0.3921506702899933,
 -0.3823157548904419,
 -0.3814927637577057,
 -0.37963366508483887,
 -0.37704402208328247,
 -0.3737722337245941,
 -0.36958760023117065,
 -0.36184120178222656,
 -0.3503398895263672,
 -0.348085880279541,
 -0.34497082233428955,
 -0.3355538547039032,
 -0.33529558777809143,
 -0.33406126499176025,
 -0.3261409401893616,
 -0.3222431540489197,
 -0.31363770365715027,
 -0.3121674358844757,
 -0.31136706471443176,
 -0.30797523260116577,
 -0.30765384435653687,
 -0.30752894282341003,
 -0.3017716109752655,
 -0.300224244594574,
 -0.296191543340683,
 -0.2946726381778717,
 -0.29312068223953247,
 -0.2895382046699524,
 -0.2894706726074219,
 -0.28895190358161926,
 -0.2887611389160156,
 -0.2886454761028