## Import Packages

In [12]:
from facenet_pytorch import MTCNN, InceptionResnetV1
import torch
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler, SequentialSampler
from torchvision import datasets
from matplotlib import pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
from src.utils.triplet_loss import BatchAllTtripletLoss
from tqdm.notebook import tqdm
from src.utils.celeba_helper import CelebADataset, CelebAClassifier, save_file_names, CelebADatasetTriplet, get_train_files_for_max_img_per_person
from src.utils.loss_functions import TripletLoss
from src.utils.similarity_functions import euclidean_distance_matrix
from importlib import reload

# Compute device: the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Worker processes for `train_loader`: none on Windows (os.name == 'nt'), two elsewhere.
workers = 2 if os.name != 'nt' else 0

# Define CelebA Dataset and Loader

In [13]:
# Create training and testing dataframes

mapping_file = 'data/identity_CelebA.txt'

# Space-separated mapping of image file name -> identity label; 'is_train' is
# overwritten below.
file_label_mapping = pd.read_csv(
    mapping_file,
    sep=" ",
    header=None,
    names=["file_name", "person_id", "is_train"],
)

# Training files chosen by the project helper (max_img_pp=5 presumably caps the
# number of images per person — TODO confirm); every other image is test data.
train_files = get_train_files_for_max_img_per_person(
    file_label_mapping=file_label_mapping,
    max_img_pp=5,
)

is_train_file = file_label_mapping['file_name'].isin(train_files)
file_label_mapping.loc[:, 'is_train'] = 0
file_label_mapping.loc[is_train_file, 'is_train'] = 1

# Dataset index of each image, taken from its numeric file-name prefix
# ("000001.jpg" -> 0); later cells index `celeba_dataset` with it.
file_label_mapping['file_id'] = file_label_mapping['file_name'].str[:6].astype('int64') - 1

train_df = file_label_mapping[file_label_mapping['is_train'] == 1]
test_df = file_label_mapping[file_label_mapping['is_train'] == 0]

In [14]:
## Load the dataset
# Directory holding every CelebA image, and the identity mapping file
img_folder = 'data/img_align_celeba'
mapping_file = 'data/identity_CelebA.txt'

image_size = 160
# Resize so the shorter side is `image_size`, centre-crop to a square, convert to a tensor.
# TODO: Add standardization. NOTE(review): confirm which input normalization the
# pretrained facenet_pytorch weights expect before changing this.
preprocess_steps = [
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
]
transform = transforms.Compose(preprocess_steps)

# Dataset over the image folder + mapping file, applying `transform` to each image
celeba_dataset = CelebADataset(img_folder, mapping_file, transform)

In [15]:
## Create a dataloader
batch_size = 8  # batch size during training
# Dataloader worker processes: none when running on CUDA, two otherwise.
# NOTE(review): this is the opposite of `workers` in the setup cell (which keys on
# the OS) and looks inverted; `celeba_dataloader` is not used in the cells below.
num_workers = 2 if device.type != 'cuda' else 0
# Put fetched batches in pinned memory only when a CUDA device is in use
pin_memory = device.type == 'cuda'

celeba_dataloader = DataLoader(
    celeba_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

In [16]:
# Dataset indices of the training split; the loader samples only these, in random order.
train_inds = train_df['file_id'].tolist()

train_loader = DataLoader(
    celeba_dataset,
    batch_size=batch_size,
    num_workers=workers,
    sampler=SubsetRandomSampler(train_inds),
)

# FaceNet Training Pipeline

## Initialize the InceptionResnetV1 model, optimizer, and loss function

In [17]:
# InceptionResnetV1 initialised with VGGFace2-pretrained weights, fine-tuned below
resnet = InceptionResnetV1(pretrained='vggface2').to(device)
criterion = BatchAllTtripletLoss()
optimizer = optim.Adam(resnet.parameters(), lr=1e-4)
eps = 1e-8  # constant to ensure no division by 0 (not referenced in the cells shown here)

In [11]:
def find_positive_observations(X, y, df, dataset=None):
    """Append every positive example of the identities in a batch to that batch.

    For each distinct label in `y`, all rows of `df` whose 'person_id' equals that
    label are looked up and the matching (image, label) pairs are fetched from
    `dataset` by the row's 'file_id'. They are appended after the original batch.
    Images already present in `X` are not filtered out, so an anchor's own image
    is appended again when its row is in `df`.

    Args:
        X (tensor): Features of images. Shape: [batch_size, channels, width, height]
        y (tensor): Labels. Shape: [batch_size]
        df (pd.DataFrame): Mapping with 'file_id' (dataset index) and 'person_id' columns.
        dataset (Dataset, optional): Indexable returning (image, label) pairs.
            Defaults to the notebook-level `celeba_dataset`.

    Returns:
        (tensor, tensor): `X` and `y` with the positive examples concatenated along
        dim 0 (returned unchanged when no positives are found).
    """
    if dataset is None:
        dataset = celeba_dataset

    pos_imgs, pos_labels = [], []
    for anchor in np.unique(y):  # sorted distinct labels present in the batch
        file_ids = df.loc[df['person_id'] == int(anchor), 'file_id'].to_numpy()
        for file_id in file_ids:
            pos_img, pos_label = dataset[int(file_id)]
            pos_imgs.append(pos_img)
            pos_labels.append(pos_label)

    if not pos_imgs:
        return X, y

    # Concatenate once at the end instead of once per image (avoids quadratic copying).
    X = torch.cat((X, torch.stack(pos_imgs)), dim=0)
    y = torch.cat((y, torch.tensor(pos_labels, dtype=y.dtype)), dim=0)
    return X, y

## Training steps

In [16]:
resnet.train()
epochs = 5
loss_total = []  # mean training loss of each epoch (plotted below)

for epoch in tqdm(range(epochs), desc="Epochs", leave=True, ncols=80, position=0):
    running_loss = []
    for X, y in tqdm(train_loader, desc="Current Batch", ncols=80, position=1, leave=False):
        # Add every training image of each identity in the batch, so identities
        # have several images (positive pairs) for the triplet loss.
        X, y = find_positive_observations(X, y, train_df)

        optimizer.zero_grad()
        # Create embeddings
        X_emb = resnet(X.to(device))
        loss = criterion(X_emb, y.to(device))
        loss.backward()
        optimizer.step()

        # .item() returns a detached Python float copied to the host
        running_loss.append(loss.item())

    epoch_loss = np.mean(running_loss)
    loss_total.append(epoch_loss)
    # Report epochs 1-based so the last one reads "5/5"
    print("Epoch: {}/{} - Loss: {:.4f}".format(epoch + 1, epochs, epoch_loss))


Epochs:   0%|                                             | 0/5 [00:00<?, ?it/s]

Current Batch:   0%|                                   | 0/6129 [00:00<?, ?it/s]

tensor(0.3795, grad_fn=<DivBackward0>)
tensor(0.4961, grad_fn=<DivBackward0>)
tensor(0.4705, grad_fn=<DivBackward0>)
tensor(0.4780, grad_fn=<DivBackward0>)


## Plot the loss curve

In [None]:
# Plot the mean training loss recorded for each epoch
fig, ax = plt.subplots()
ax.plot(loss_total)
ax.set(xlabel="Epochs", ylabel="TripletLoss", title="Training loss")
plt.show()

## Test the trained model

In [None]:
# Switch to inference mode (affects layers such as dropout/batch-norm) on `device`
resnet.to(device).eval()

## Accuracy of the model

In [None]:
vault_path = "data/oneshot_vault"
label_file = "data/identity_vault_person.txt"

def load_image(path, transform):
    """Open an image file as RGB and apply `transform` to it when one is given.

    Args:
        path (str): Path to the image file.
        transform (callable | None): Applied to the RGB PIL image when truthy.

    Returns:
        The transformed image, or the RGB PIL image when `transform` is falsy.
    """
    with Image.open(path) as raw_img:
        rgb_img = raw_img.convert("RGB")
    return transform(rgb_img) if transform else rgb_img
    
def create_embeddings(folder, label_file, model, transform, device=None):
    """Embed one reference image per file name listed in `label_file`.

    Each non-blank line of `label_file` is "<file_name> <label>" separated by
    whitespace; only the first occurrence of a file name is kept.

    Args:
        folder (str): Directory containing the reference images.
        label_file (str): Path to the file-name -> label mapping.
        model (torch.nn.Module): Embedding network, called on a batch of one image.
        transform (callable): Preprocessing passed to `load_image`.
        device (torch.device, optional): Device to run the model on. Defaults to
            the device of the model's first parameter (CPU if it has none).

    Returns:
        (tensor, list[str]): CPU tensor with one embedding row per file, and the
        label (string as read from the file) of each row.
    """
    label_file_dict = {}
    with open(label_file, 'r') as r_file:
        for line in r_file:
            parts = line.split()
            if not parts:
                continue  # tolerate blank / trailing lines
            # keep the first label seen for each file name
            label_file_dict.setdefault(parts[0], parts[1])

    if device is None:
        first_param = next(model.parameters(), None) if hasattr(model, "parameters") else None
        device = first_param.device if first_param is not None else torch.device("cpu")

    gt_labels = []
    rows = []
    with torch.no_grad():  # inference only: don't build an autograd graph
        for file_name, label in label_file_dict.items():
            img = load_image(os.path.join(folder, file_name), transform)
            # add a batch dimension and move the input to the model's device
            img_emb = model(img[None, :].to(device))
            rows.append(img_emb.cpu())
            gt_labels.append(label)

    if not rows:
        return torch.empty(0, 512), gt_labels
    return torch.cat(rows, dim=0), gt_labels


resnet.eval().to(device)

# Embed the reference ("vault") images. This reuses `transform` from the
# dataset-loading cell instead of redefining it, so reference images always get
# exactly the same preprocessing as the training images.
embeddings, gt_labels = create_embeddings(
    folder=vault_path, label_file=label_file, model=resnet, transform=transform
)
                        

In [None]:
# Test image:
def calculate_label(test_image_file, img_folder, transform, embeddings, model=None, labels=None):
    """Predict the identity of one image by nearest-neighbour search over `embeddings`.

    Args:
        test_image_file (str): File name of the image inside `img_folder`.
        img_folder (str): Directory containing the image.
        transform (callable): Preprocessing passed to `load_image`.
        embeddings (tensor): Reference embeddings, one row per reference image.
        model (torch.nn.Module, optional): Embedding network. Defaults to the
            notebook-level `resnet`.
        labels (sequence, optional): Label of each row of `embeddings`. Defaults
            to the notebook-level `gt_labels`.

    Returns:
        int: Label of the reference row with the smallest squared Euclidean
        distance to the image's embedding.
    """
    if model is None:
        model = resnet
    if labels is None:
        labels = gt_labels

    test_img = load_image(os.path.join(img_folder, test_image_file), transform=transform)

    # Add a batch dimension and run on whatever device the model lives on
    batch = test_img[None, :]
    first_param = next(model.parameters(), None) if hasattr(model, "parameters") else None
    if first_param is not None:
        batch = batch.to(first_param.device)

    with torch.no_grad():  # inference only
        test_img_emb = model(batch).squeeze(0).to(embeddings.device)

    # Squared Euclidean distance to every reference embedding (already non-negative)
    distances = (test_img_emb - embeddings).pow(2).sum(dim=1)
    return int(labels[int(torch.argmin(distances))])

# Sanity check: predict the identity of a single image (shown as the cell output).
test_image_file = "000032.jpg"
calculate_label(
    test_image_file=test_image_file,
    img_folder=img_folder,
    transform=transform,
    embeddings=embeddings,
)


In [None]:

# Evaluate on the held-out split only: `test_df` contains every image NOT selected
# for training, so training images no longer inflate the accuracy.
test_predictions = []
test_gt_labels = []

with torch.no_grad():  # inference only: no autograd graph per image
    for file_name, person_id in tqdm(
        zip(test_df['file_name'], test_df['person_id']),
        total=len(test_df), desc="Evaluating", ncols=80,
    ):
        test_gt_labels.append(int(person_id))
        test_predictions.append(calculate_label(file_name, img_folder, transform, embeddings))

correct = torch.tensor(test_predictions) == torch.tensor(test_gt_labels)
accuracy = correct.float().mean().item()
print(f'Accuracy for the model: {accuracy}')

