In [1]:
import numpy as np
from facenet_pytorch import MTCNN, InceptionResnetV1, fixed_image_standardization
import torch
from matplotlib import pyplot as plt
import pandas as pd
import os
import torch
from PIL import Image
from torchvision import transforms
from tqdm.notebook import tqdm
import pickle
import random
from sklearn.metrics import roc_curve

# Number of worker processes: 0 on Windows (os.name == 'nt'), otherwise 2.
# NOTE(review): not referenced by any cell shown here; presumably meant for a DataLoader.
workers = 2 if os.name != 'nt' else 0

In [2]:
# Run the model on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device('cpu' if not torch.cuda.is_available() else 'cuda:0')

In [3]:
# Inputs: image folder (presumably MTCNN crops, per its name) and the
# space-separated "file_name person_id" identity mapping file.
img_folder, mapping_file = (
    '../one-shot-face-recognition-main-1125/all_images_mtcnn',
    '../one-shot-face-recognition-main-1125/identity_CelebA_all.txt',
)

# Outputs written by this notebook: pickled test pairs and per-image embedding files.
splits_file, embedding_folder = "test_splits_FE.pkl", './embeddings_FE'

# State dict loaded into InceptionResnetV1 before computing embeddings.
model_statedict = './facenet_model_statedict_epochs150_margin0.5_lr0.1_schedule40_70_95_130.pth'

# Exclude augmentation images (file names containing `_`)

In [4]:
# Keep .jpg files (case-insensitive) whose names contain no '_' (augmented copies).
image_names = [
    name for name in os.listdir(img_folder)
    if '_' not in name and name.lower().endswith('.jpg')
]
len(image_names)

202599

In [5]:
# Read the space-separated "<file_name> <person_id>" identity mapping.
mapping = pd.read_csv(
            mapping_file, header=None, sep=" ", names=["file_name", "person_id"])

# Build both lookup directions, skipping augmented images (names containing '_'):
#   dict_image_person:  file_name -> person_id
#   dict_person_images: person_id -> [file_name, ...] in mapping-file order
# Iterating the two columns directly avoids DataFrame.iterrows, which constructs a
# Series for every one of the ~200k rows and is very slow.
dict_image_person = {}
dict_person_images = {}
for file_name, person_id in zip(mapping['file_name'], mapping['person_id']):
    if '_' in file_name:
        continue
    dict_person_images.setdefault(person_id, []).append(file_name)
    dict_image_person[file_name] = person_id

len(dict_image_person), len(dict_person_images)

(202599, 10177)

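# Create ten disjoint test splits (1,000 identities each; 1,000 same-identity and 1,000 different-identity pairs per split)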

In [None]:
# Every identity that has at least one non-augmented image.
people = [*dict_person_images]
len(people)

In [None]:
# Partition identities into disjoint test splits of equal size.
# The fixed seed makes split generation reproducible (the pair sampling in the next
# cell also draws from `random`). Shuffling a copy leaves `people` untouched, so
# re-running this cell yields the same splits.
# Identities beyond N_SPLITS * PEOPLE_PER_SPLIT are not used.
SPLIT_SEED = 0
N_SPLITS = 10
PEOPLE_PER_SPLIT = 1000

random.seed(SPLIT_SEED)
shuffled_people = people.copy()
random.shuffle(shuffled_people)
assert len(shuffled_people) >= N_SPLITS * PEOPLE_PER_SPLIT, 'not enough identities for the requested splits'
people_sets = [
    shuffled_people[k * PEOPLE_PER_SPLIT:(k + 1) * PEOPLE_PER_SPLIT]
    for k in range(N_SPLITS)
]
  

In [None]:
# For every split, sample distinct same-identity and different-identity image pairs.
# Pairs are stored as sorted tuples so (a, b) and (b, a) are the same pair.
PAIRS_PER_KIND = 1000

dict_set_pairs = {}
for i, people_set in enumerate(people_sets):
    print('i:',i)
    # Only identities with two or more images can form a same-identity pair. Drawing
    # uniformly from them matches the old rejection sampling over all identities,
    # but never spins on single-image people.
    eligible = [p for p in people_set if len(dict_person_images[p]) > 1]
    n_possible_same = sum(
        len(dict_person_images[p]) * (len(dict_person_images[p]) - 1) // 2 for p in eligible
    )
    if n_possible_same < PAIRS_PER_KIND:
        # Without this guard the loop below would never terminate.
        raise ValueError(f'split {i} can yield only {n_possible_same} distinct same-identity pairs')

    same_pairs, diff_pairs = set(), set()
    while len(same_pairs) < PAIRS_PER_KIND:
        person = random.choice(eligible)
        # set.add ignores duplicates, so no membership check is needed.
        same_pairs.add(tuple(sorted(random.sample(dict_person_images[person], 2))))
    while len(diff_pairs) < PAIRS_PER_KIND:
        # random.sample guarantees two distinct identities.
        person_a, person_b = random.sample(people_set, 2)
        img1 = random.choice(dict_person_images[person_a])
        img2 = random.choice(dict_person_images[person_b])
        diff_pairs.add(tuple(sorted([img1, img2])))
    dict_set_pairs[i] = {'same': same_pairs, 'diff': diff_pairs}

In [None]:
# Persist the sampled pairs so later evaluation always uses the same splits.
with open(splits_file, "wb") as splits_fh:
    pickle.dump(dict_set_pairs, splits_fh)

In [6]:
# Reload the pairs written to `splits_file` by this notebook
# (pickle can execute code: only unpickle files you trust).
with open(splits_file, "rb") as splits_fh:
    dict_set_pairs = pickle.load(splits_fh)

# Compute and cache embeddings for all images used in the splits

In [7]:
# Build InceptionResnetV1 with VGGFace2 weights, move it to `device`, then replace
# its parameters with those saved in `model_statedict`.
resnet = InceptionResnetV1(pretrained='vggface2')
resnet = resnet.to(device)
# The checkpoint is deserialised onto the CPU; load_state_dict copies the tensors
# into the model's existing parameters.
state_dict = torch.load(model_statedict, map_location=torch.device('cpu'))
resnet.load_state_dict(state_dict)
del state_dict  # drop the extra CPU copy of the weights
_ = resnet.eval()  # evaluation mode; assignment suppresses the cell's repr output

In [8]:
def get_emb(f):
    """Compute the `resnet` embedding of one image from `img_folder`.

    Args:
        f: image file name, relative to `img_folder`.

    Returns:
        The model output for a batch containing only this image, as a CPU tensor
        with no autograd history.
    """
    image_size = 160
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        np.float32,  # PIL image -> float32 H x W x C array
        transforms.ToTensor(),  # H x W x C array -> C x H x W tensor
        fixed_image_standardization,
    ])
    path = os.path.join(img_folder, f)
    # `with` closes the file handle deterministically; convert() loads the pixels first.
    with Image.open(path) as raw:
        image = raw.convert("RGB")
    img = transform(image)
    # Inference only: no_grad avoids building an autograd graph on every call.
    with torch.no_grad():
        e = resnet(img.unsqueeze(0).to(device)).cpu()
    return e

In [9]:
# Cache directory for per-image embedding files; no error if it already exists.
os.makedirs(name=embedding_folder, exist_ok=True)

In [None]:
# Every image occurrence across all pairs of all splits (with repeats).
pair_images = [
    image
    for split in range(len(dict_set_pairs))
    for kind in ['same', 'diff']
    for pair in dict_set_pairs[split][kind]
    for image in (pair[0], pair[1])
]
print(len(pair_images))
# An image can appear in several pairs; embed each one only once.
all_images_in_pairs = list(set(pair_images))
print(len(all_images_in_pairs))


# Embeddings already on disk are skipped, so an interrupted run can be resumed.
for image_name in tqdm(all_images_in_pairs):
    e_path = os.path.join(embedding_folder, image_name.replace('.jpg', '.pt'))
    if os.path.exists(e_path):
        continue
    torch.save(get_emb(image_name), e_path)


40000
34870


  0%|          | 0/34870 [00:00<?, ?it/s]

In [None]:
# Sanity check: embedding distance between two images of person id 1.
person1_images = dict_person_images[1]
e1 = get_emb(person1_images[5])
e2 = get_emb(person1_images[2])
(e2 - e1).norm()

In [None]:
# Distance from e1 to an arbitrary image. os.listdir order is unspecified, so this
# image is not guaranteed to belong to a different identity.
probe_image = image_names[123]
e3 = get_emb(probe_image)
(e3 - e1).norm()

# Cross-validation (threshold fitted on all other splits, accuracy measured on the held-out split)

In [None]:
# Reload the pairs from the configured `splits_file` rather than a second hard-coded
# copy of the path (pickle can execute code: only unpickle files you trust).
with open(splits_file, "rb") as file:
    dict_set_pairs = pickle.load(file)

In [None]:
def load_emb(f):
    """Load the cached embedding tensor for image `f` from `embedding_folder`.

    The cache file name is the image name with '.jpg' replaced by '.pt'.
    """
    return torch.load(os.path.join(embedding_folder, f.replace('.jpg', '.pt')))


In [None]:
def _pair_distance(pair):
    """Norm of the difference between the cached embeddings of the two images in `pair`."""
    first, second = load_emb(pair[0]), load_emb(pair[1])
    return float((first - second).norm())


# Per split: list of embedding distances for 'same' pairs and for 'diff' pairs.
dict_set_distances = {}
for i in range(len(dict_set_pairs)):
    print('i:',i)
    dict_set_distances[i] = {
        kind: [_pair_distance(pair) for pair in dict_set_pairs[i][kind]]
        for kind in ('same', 'diff')
    }

In [None]:
def opt_threshold(value_label):
    """Pick the distance threshold that maximises the ROC G-mean.

    Args:
        value_label: iterable of ``(distance, label)`` tuples. Label ``'same'`` is
            the negative class; any other label is treated as ``'diff'`` (positive).

    Returns:
        The roc_curve threshold maximising ``sqrt(TPR * (1 - FPR))``, rounded to
        4 decimals. With 'diff' as the positive class, distances >= threshold are
        predicted to be different people.
    """
    pairs = list(value_label)  # iterate once, so generators are supported
    # Positive class (1) = different people: larger distance means more likely 'diff'.
    y_true = [0 if label == 'same' else 1 for _, label in pairs]
    y_score = [v for v, _ in pairs]

    fpr, tpr, thresholds = roc_curve(y_true, y_score)

    # G-mean balances sensitivity (TPR) and specificity (1 - FPR).
    gmean = np.sqrt(tpr * (1 - fpr))

    # NOTE(review): rounding shifts the decision boundary slightly away from the ROC
    # threshold; kept so reported thresholds are unchanged.
    return round(thresholds[np.argmax(gmean)], ndigits = 4)

def get_acc(value_label, thresholdOpt):
    """Accuracy of the rule "distance >= thresholdOpt -> different people".

    Args:
        value_label: sequence of ``(distance, label)`` tuples, label ``'same'`` or ``'diff'``.
        thresholdOpt: decision threshold, e.g. as returned by ``opt_threshold``.

    Returns:
        Fraction of pairs classified correctly. Entries whose label is neither
        'same' nor 'diff' are never counted as errors but are included in the total.

    Raises:
        ValueError: if ``value_label`` is empty.
    """
    if len(value_label) == 0:
        raise ValueError('value_label must contain at least one (distance, label) pair')
    # A distance exactly at the threshold is predicted 'diff', matching roc_curve's
    # ``score >= threshold`` convention used to choose the threshold. Previously such
    # a pair was counted as correct whatever its label.
    FP = sum(1 for d, label in value_label if label == 'same' and d >= thresholdOpt)  # same predicted diff
    FN = sum(1 for d, label in value_label if label == 'diff' and d < thresholdOpt)   # diff predicted same
    return 1 - (FP + FN) / len(value_label)

def cv(hold_out):
    """Evaluate one cross-validation fold.

    The threshold is fitted with ``opt_threshold`` on the distances of every split
    except ``hold_out``, then scored with ``get_acc`` on split ``hold_out``.

    Args:
        hold_out: index of the split in ``dict_set_distances`` used for testing.

    Returns:
        ``(thresholdOpt, test_acc)``.
    """
    def labelled(same, diff):
        # Tag each distance with its pair kind and order the result by distance.
        tagged = [(d, 'same') for d in same] + [(d, 'diff') for d in diff]
        return sorted(tagged, key=lambda item: item[0])

    train_folds = [k for k in range(len(dict_set_distances)) if k != hold_out]
    train = labelled(
        [d for k in train_folds for d in dict_set_distances[k]['same']],
        [d for k in train_folds for d in dict_set_distances[k]['diff']],
    )
    thresholdOpt = opt_threshold(train)

    test_split = dict_set_distances[hold_out]
    test = labelled(test_split['same'], test_split['diff'])
    test_acc = get_acc(test, thresholdOpt)
    return thresholdOpt, test_acc

In [None]:
# Every split serves as the test fold exactly once; report per-fold and mean results.
thresholds, accs = [], []
for i in range(len(dict_set_distances)):
    fold_threshold, fold_acc = cv(i)
    thresholds.append(fold_threshold)
    accs.append(fold_acc)
    print('fold: %s \t optimal threshold: %s \t test acc: %s' % (i, fold_threshold, round(fold_acc, 4)))
print('\nAvg: threshold %s \t acc %s' % (np.mean(thresholds), np.mean(accs)))