## Import Packages

In [1]:
from facenet_pytorch import MTCNN, InceptionResnetV1
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from matplotlib import pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import os
import zipfile 
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from sklearn.metrics import accuracy_score
from src.utils.celeba_dataset import CelebADataset

# DataLoader worker count: 0 (load in the main process) on Windows, 2 on other platforms.
# NOTE(review): not referenced by any visible cell; the dataloader cell computes its own num_workers.
workers = 2 if os.name != 'nt' else 0

In [None]:
# Use the first CUDA GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(f'Running on device: {device}')

Running on device: cuda:0


# Define CelebA Dataset and Loader

In [None]:
## Load the dataset
# Directory containing the aligned CelebA images
img_folder = 'data/img_align_celeba'
# Identity file handed to CelebADataset (presumably image file name -> person id)
mapping_file = 'data/identity_CelebA.txt'

# Side length, in pixels, of the square images produced by `transform`
image_size = 160
# Resize the shorter edge to image_size, center-crop to a square, then convert to a tensor
transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
    ]
)

# Build the dataset; `transform` is presumably applied to each image it yields (see CelebADataset)
celeba_dataset = CelebADataset(img_folder, mapping_file, transform)

In [None]:
## Create a dataloader
# Number of images per batch yielded by the loader
batch_size = 128
# Load in the main process (0 workers) when running on CUDA, with 2 worker processes otherwise
num_workers = 2 if device.type != 'cuda' else 0
# Page-locked host memory only speeds up host-to-GPU copies, so enable it just for CUDA
pin_memory = device.type == 'cuda'

celeba_dataloader = torch.utils.data.DataLoader(
    celeba_dataset,
    batch_size=batch_size,
    # Deterministic order: the first N batches are the same images on every run
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

# Set Up FaceNet

## Define MTCNN module

In [173]:
# Face detector / aligner. With keep_all=False it returns at most one face per image,
# cropped to image_size x image_size (or None when no face is found).
mtcnn = MTCNN(
    image_size=image_size,
    margin=0,
    min_face_size=20,
    # Detection thresholds for the three MTCNN stages
    thresholds=[0.6, 0.7, 0.7],
    factor=0.709,
    post_process=True,
    keep_all=False,
    device=device,
)

## Define Inception Resnet V1 module

In [174]:
# Face-embedding network with VGGFace2 pretrained weights, switched to inference mode
# and moved to the selected device.
resnet = InceptionResnetV1(pretrained='vggface2')
resnet.eval()
resnet = resnet.to(device)

## Load data (compute face embeddings or load them from cache)

In [None]:
# True: recompute the train embeddings with load_data and write a new cache.
# False: read a previously written cache from pytorch_objects/.
load_new_data = False
# Suffix identifying which cached run to read when load_new_data is False
file_number_to_load = '028289'
# Number of dataloader batches handed to load_data when recomputing
num_batches = 220

In [160]:
# Custom data loading function

def load_data(dataloader: torch.utils.data.DataLoader, num_batches: int):
    """Detect, align and embed the faces in the first ``num_batches`` batches of ``dataloader``.

    Each batch must unpack to ``(images, file_names)``. Every image is converted back to a
    PIL image and passed to the global ``mtcnn`` detector; images without a detected face
    are skipped. The aligned faces of each batch are embedded by the global ``resnet`` on
    the global ``device``.

    Args:
        dataloader: loader yielding ``(images, file_names)`` batches.
        num_batches: maximum number of batches to process; ``None`` processes every batch.

    Returns:
        ``(embeddings, face_file_names)``: a CPU tensor with one embedding row per detected
        face (``None`` if no face was embedded at all) and the file names of those faces,
        in the same order as the rows.
    """
    to_pil = transforms.ToPILImage()
    embedding_chunks = []
    face_file_names = []

    for idx, batch in enumerate(dataloader):
        # Check before processing so exactly `num_batches` batches are embedded
        # (checking after processing, as before, embedded one batch too many).
        if num_batches is not None and idx >= num_batches:
            break

        train_features, file_names = batch
        faces = []
        for train_feature, file_name in zip(train_features, file_names):
            # keep_all=False on the detector: one aligned face tensor, or None if no face.
            face = mtcnn(to_pil(train_feature))
            if face is not None:
                faces.append(face)
                face_file_names.append(file_name)

        if not faces:
            # Nothing to embed; feeding resnet an empty placeholder would fail.
            print(f'Batch {idx}. No faces detected, skipping.')
            continue

        # Collect the batch once instead of repeatedly concatenating (quadratic copying).
        aligned = torch.stack(faces)
        print(f'Batch {idx}. Batch shape: {aligned.shape}')
        aligned = aligned.to(device)
        # Inference only: skip building an autograd graph for the embeddings.
        with torch.no_grad():
            embedding_chunks.append(resnet(aligned).cpu())

    embeddings = torch.cat(embedding_chunks) if embedding_chunks else None
    return embeddings, face_file_names

if load_new_data:
    train_embeddings, train_face_file_names = load_data(celeba_dataloader, num_batches)
    if not train_face_file_names:
        raise RuntimeError('load_data found no faces; nothing to cache.')

    # Cache files are named after the last embedded image, without its extension.
    # (Previously this read `face_file_names`, which only exists inside load_data.)
    last_file_stem = os.path.splitext(train_face_file_names[-1])[0]
    os.makedirs('pytorch_objects', exist_ok=True)
    torch.save(train_embeddings, f'pytorch_objects/embeddings_up_to_img_{last_file_stem}.pickle')
    with open(f'pytorch_objects/file_names_up_to_img_{last_file_stem}', 'w') as fp:
        # One file name per line, in the same order as the embedding rows.
        fp.writelines(f'{item}\n' for item in train_face_file_names)
    print('Done')
else:
    # torch.load unpickles the file: only load caches written by this notebook.
    train_embeddings = torch.load(f'pytorch_objects/embeddings_up_to_img_{file_number_to_load}.pickle')
    with open(f'pytorch_objects/file_names_up_to_img_{file_number_to_load}', 'r') as fp:
        # rstrip('\n') instead of line[:-1], which would chop a real character off a
        # final line that has no trailing newline.
        train_face_file_names = [line.rstrip('\n') for line in fp]

# Row i of train_embeddings must belong to train_face_file_names[i].
assert len(train_face_file_names) == len(train_embeddings), 'cached embeddings and file names are out of sync'

print(train_embeddings.shape)

torch.Size([28246, 512])


In [161]:
# Table mapping each image file name to its identity label; defined here so the
# notebook also runs on a fresh kernel (no earlier cell creates it).
# Assumes identity_CelebA.txt holds one "<file_name> <person_id>" pair per line with
# no header (the published CelebA format) — confirm if the data file changes.
file_label_mapping = pd.read_csv(mapping_file, sep=r'\s+', header=None, names=['file_name', 'person_id'])

train_labels = file_label_mapping[file_label_mapping['file_name'].isin(train_face_file_names)]['person_id'].values
print(f'Number of people in train dataset: {len(np.unique(train_labels))}')

Number of people in train dataset: 7390


In [144]:
# Embed the test faces (at most 10 batches).
# NOTE(review): celeba_test_dataloader is not defined in any visible cell — TODO define it
# in an earlier cell so the notebook survives Restart & Run All.
test_embeddings, test_face_file_names = load_data(celeba_test_dataloader, 10)

# Look up each test face's label in the same order as test_face_file_names, so that
# test_labels[i] belongs to test_embeddings[i] (and thus to prediction i). A boolean
# `isin` filter would return rows in the mapping file's order instead, which only lines
# up when both orders happen to agree. A missing file name raises KeyError.
test_label_by_file = file_label_mapping.drop_duplicates('file_name').set_index('file_name')['person_id']
test_labels = test_label_by_file.loc[test_face_file_names].tolist()

Batch 0. Batch shape: torch.Size([128, 3, 160, 160])
Batch 1. Batch shape: torch.Size([128, 3, 160, 160])
Batch 2. Batch shape: torch.Size([15, 3, 160, 160])
torch.Size([271, 512])


## Make Predictions

In [None]:
def predict(test_embeddings: torch.Tensor, train_embeddings: torch.Tensor, file_label_mapping,
            train_file_names=None):
    """Predict a person_id for every test embedding by 1-nearest-neighbour search.

    For each test embedding, the train embedding at the smallest Euclidean distance is
    found, and the person_id of that train image (looked up by file name) is predicted.

    Args:
        test_embeddings: tensor whose rows are the query embeddings.
        train_embeddings: tensor whose rows are the reference embeddings.
        file_label_mapping: DataFrame with 'file_name' and 'person_id' columns.
        train_file_names: file name of each row of train_embeddings, in the same order.
            If omitted, falls back to the global ``face_file_names`` (legacy behaviour);
            pass it explicitly to avoid depending on kernel state.

    Returns:
        ``(predictions, predictions_files)``: the predicted person_id and the matched train
        file name for every test embedding, in input order.

    Raises:
        ValueError: if train_embeddings is empty, or train_file_names does not have exactly
            one entry per train embedding.
    """
    if train_file_names is None:
        # Legacy fallback for 3-argument calls; the original implementation read this global.
        train_file_names = face_file_names
    if len(train_embeddings) == 0:
        raise ValueError('train_embeddings is empty')
    if len(train_file_names) != len(train_embeddings):
        raise ValueError(
            f'train_file_names has {len(train_file_names)} entries but '
            f'train_embeddings has {len(train_embeddings)} rows'
        )

    # Build the file_name -> person_id lookup once instead of scanning the frame per test
    # embedding; keep the first row per file name, like the former `.values[0]`.
    label_by_file = file_label_mapping.drop_duplicates('file_name').set_index('file_name')['person_id']

    predictions = []
    predictions_files = []
    for test_embedding in test_embeddings:
        # Distances to all train embeddings in one vectorized op (was one .item() per row).
        dists = (train_embeddings - test_embedding).norm(dim=1)
        # argmin returns the first index on ties, like np.argmin did.
        closest_image_file_name = train_file_names[int(torch.argmin(dists))]
        predictions.append(label_by_file.at[closest_image_file_name])
        predictions_files.append(closest_image_file_name)

    return predictions, predictions_files

test_predictions, test_predictions_files = predict(
    test_embeddings, train_embeddings, file_label_mapping, train_file_names=train_face_file_names
)

In [163]:
# Fraction of test faces whose predicted identity equals the true identity.
# (accuracy_score is imported once, in the first cell.)
accuracy = accuracy_score(test_labels, test_predictions)
print(f'Accuracy: {np.round(accuracy, 4)}')

Accuracy: 0.8007
