In [None]:
from mlmodule.utils import list_files_in_dir
from mlmodule.torch.data.images import ImageDataset
from mlmodule.contrib.arcface import ArcFaceFeatures
from mlmodule.contrib.magface import MagFaceFeatures
from mlmodule.contrib.mtcnn import MTCNNDetector
from mlmodule.torch.data.box import BoundingBoxDataset
from mlmodule.box import BBoxOutput
import matplotlib.pyplot as plt
import torch
import numpy as np
from typing import List
from facenet_pytorch.models.utils.detect_face import crop_resize
from PIL import Image
import seaborn as sns
sns.set(style="white")
%matplotlib inline


In [None]:
# Set the AWS credentials as environment variables for this kernel session.
# Replace each <...> placeholder with your own value before running the cell,
# and never commit the notebook with real credentials filled in.
%env AWS_ACCESS_KEY_ID = <please your key id here>
%env AWS_SECRET_ACCESS_KEY = <please your secret access key here>

In [None]:
# Load the models used throughout the notebook.
# Prefer the first CUDA GPU, but fall back to the CPU so the notebook can
# still run (more slowly) on a machine without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Two face-embedding models, compared against each other further down.
arcface = ArcFaceFeatures(device=device).load()
magface = MagFaceFeatures(device=device).load()
# Face detector; the (720, 720) image size is also hard-coded where the
# aspect ratios are computed in the cropping cell.
mtcnn = MTCNNDetector(device=device, image_size=(720, 720), min_face_size=20).load()


In [None]:
# Run face detection first.
# Build an image dataset from the files listed under the fixture directory
# (presumably only those with a .jpg extension - confirm against list_files_in_dir).
base_path = "../tests/fixtures/berset"
file_names = list_files_in_dir(base_path, allowed_extensions=('jpg',))
dataset = ImageDataset(file_names)
# Detect faces. bulk_inference returns two values: `file_names` is rebound to
# the first, and `outputs` receives the second, which the flattening cell
# iterates as one collection of bounding boxes per file name.
file_names, outputs = mtcnn.bulk_inference(dataset)

In [None]:
# Flatten the per-file detections into one record per detected face.
# Each record is (face id, file name, bounding box), where the face id is
# "<file name>_<position of the face within that file's detections>".
# NOTE: this cell rebinds `file_names` from one entry per file to one entry
# per face, so re-run the detection cell before re-running this one;
# otherwise file names and `outputs` no longer line up.
bboxes: List[BBoxOutput]
indices: List[str]
face_records = [
    (f'{fn}_{i}', fn, bbox)
    for fn, bbox_list in zip(file_names, outputs)
    for i, bbox in enumerate(bbox_list)
]
if not face_records:
    # zip(*[]) yields nothing, which would otherwise surface as an opaque
    # "not enough values to unpack" ValueError.
    raise ValueError('No faces were detected in the input images')
# zip() yields tuples; materialise lists so the values match the
# List[...] annotations declared above.
indices, file_names, bboxes = (list(column) for column in zip(*face_records))

In [None]:
# ---- ArcFace embeddings ----
# Wrap the flattened face records in a bounding-box dataset, then run the
# model over it with a batch size of 12.
bbox_features = BoundingBoxDataset(indices, file_names, bboxes)
d_indices, arcface_features = arcface.bulk_inference(
    bbox_features,
    remove_bad_quality_faces=False,
    data_loader_options={
        'batch_size': 12,
        'num_workers': 0,
        'pin_memory': True,
    },
    tqdm_enabled=True,
)

# ---- MagFace embeddings ----
# Same procedure on a newly built dataset. Note that `bbox_features` and
# `d_indices` are both overwritten by this second pass.
bbox_features = BoundingBoxDataset(indices, file_names, bboxes)
d_indices, magface_features = magface.bulk_inference(
    bbox_features,
    remove_bad_quality_faces=False,
    data_loader_options={
        'batch_size': 12,
        'num_workers': 0,
        'pin_memory': True,
    },
    tqdm_enabled=True,
)


In [None]:
def image_grid(array, ncols=10):
    """Tile a batch of equally sized images into a single grid image.

    Args:
        array: Array of shape (n_images, height, width, channels).
        ncols: Number of images per grid row; must be at least 1.

    Returns:
        Array of shape (height * nrows, width * ncols, channels), where
        nrows = ceil(n_images / ncols). Images are placed row by row in
        input order; unused cells of the last row are filled with zeros.

    Raises:
        ValueError: If ``ncols`` is smaller than 1.
    """
    if ncols < 1:
        raise ValueError(f"ncols must be >= 1, got {ncols}")

    index, height, width, channels = array.shape
    # Ceiling division, so a partially filled last row still gets a grid row
    # instead of making the reshape below fail.
    nrows = -(-index // ncols)

    n_missing = nrows * ncols - index
    if n_missing:
        # Pad with blank images so the batch fills the grid exactly.
        padding = np.zeros((n_missing, height, width, channels), dtype=array.dtype)
        array = np.concatenate([array, padding], axis=0)

    # (row, col, h, w, c) -> (row, h, col, w, c) -> (row*h, col*w, c)
    img_grid = (array.reshape(nrows, ncols, height, width, channels)
                .swapaxes(1, 2)
                .reshape(height*nrows, width*ncols, channels))

    return img_grid


# Crop every detected face out of its source image and show all crops side
# by side in a single-row grid.
img_arr = []        # one 112x112 face crop per detected face
img_size = []       # size of each source image as reported by PIL
aspect_ratios = []  # [height / 720, width / 720] per source image
for file_name, face_bbox in zip(file_names, bboxes):
    # Use a context manager so the underlying file handle is released as
    # soon as the crop has been copied into a numpy array.
    with Image.open(file_name) as img:
        img_size.append(img.size)
        width, height = img.size
        # 720 matches the image_size given to the MTCNN detector.
        aspect_ratios.append([x/target for x, target in zip((height, width), (720, 720))])
        # Flatten the two corner points into [x0, y0, x1, y1] for crop_resize.
        first_corner = face_bbox.bounding_box[0]
        second_corner = face_bbox.bounding_box[1]
        box = np.array([first_corner.x, first_corner.y,
                        second_corner.x, second_corner.y])
        cropped_face = np.asarray(crop_resize(img, box, image_size=112))
    img_arr.append(cropped_face)

# ncols equals the number of faces, so the grid is one row wide.
result = image_grid(np.array(img_arr), ncols=len(file_names))
fig = plt.figure(figsize=(20., 20.))
plt.imshow(result)


## Face similarity with ArcFace

In [None]:
# Pairwise dot products between all ArcFace feature vectors
# (assumes one feature vector per row - TODO confirm).
sim_mat = arcface_features @ arcface_features.T
# Draw the annotated heatmap on an explicitly created Axes.
fig, ax = plt.subplots(figsize=(8, 6))
ax = sns.heatmap(sim_mat, annot=True, cmap="PuRd", ax=ax)


## Face similarity with MagFace

In [None]:
# Scale every MagFace feature row to unit L2 norm, so the dot products
# computed below are cosine similarities.
normalized_features = (
    magface_features
    / np.linalg.norm(magface_features, axis=1, keepdims=True)
)
sim_mat = normalized_features @ normalized_features.T
# Draw the annotated heatmap on an explicitly created Axes.
fig, ax = plt.subplots(figsize=(8, 6))
ax = sns.heatmap(sim_mat, annot=True, cmap="PuRd", ax=ax)
