In [None]:
import torch
import numpy as np
from typing import List
from facenet_pytorch.models.utils.detect_face import crop_resize
from PIL import Image
import seaborn as sns
sns.set(style="white")
import matplotlib.pyplot as plt
%matplotlib inline

from mlmodule.box import BBoxOutput
from mlmodule.torch.data.box import BoundingBoxDataset
from mlmodule.contrib.mtcnn import MTCNNDetector
from mlmodule.contrib.magface import MagFaceFeatures
from mlmodule.torch.data.images import ImageDataset
from mlmodule.utils import list_files_in_dir


In [None]:
# Export AWS credentials into the kernel's environment. The values below are
# placeholders and must be replaced before this cell is run.
# NOTE(review): presumably required by the model `.load()` calls further down to
# fetch weights — confirm. Never commit real keys; clear this cell before sharing.
%env AWS_ACCESS_KEY_ID = <please your key id here>
%env AWS_SECRET_ACCESS_KEY = <please your secret access key here>

In [None]:
# Load the face feature extractor (MagFace) and the face detector (MTCNN).
# Use the first GPU when CUDA is available and fall back to the CPU otherwise,
# so the notebook still runs on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
magface = MagFaceFeatures(device=device).load()
mtcnn = MTCNNDetector(device=device, image_size=(720,1080), min_face_size=20).load()

In [None]:
# run face detection first
base_path = "../tests/fixtures/remi_faces"
# Only files with a 'jpg' extension are requested from the fixture directory.
file_names = list_files_in_dir(base_path, allowed_extensions=('jpg',))
dataset = ImageDataset(file_names)
# Detect faces first
# NOTE: `file_names` is rebound here to the first value returned by
# bulk_inference. The next cell zips it with `outputs`, so the two are
# presumably aligned one-to-one (one list of boxes per file) — TODO confirm.
file_names, outputs = mtcnn.bulk_inference(dataset)


In [None]:
# Flattening all detected faces
# One (index, file name, box) triple is produced per detected face; the index is
# "<file name>_<position of the box within that file's list>".
# NOTE: zip(*...) yields tuples (despite the List annotations below) and the
# unpacking raises ValueError when no face was detected at all.
# NOTE: `file_names` is rebound to the per-face tuple, so re-running this cell
# alone would zip that per-face tuple against `outputs`; re-run the detection
# cell first.
bboxes: List[BBoxOutput]
indices: List[str]
indices, file_names, bboxes = zip(*[
    (f'{fn}_{i}', fn, bbox) for fn, bbox_list in zip(file_names, outputs) for i, bbox in enumerate(bbox_list)
])
# Create a dataset for the bounding boxes
bbox_features = BoundingBoxDataset(indices, file_names, bboxes)

# Get face features
# remove_bad_quality_faces=False presumably keeps every face so that the quality
# ranking below covers all detections — TODO confirm against MagFaceFeatures.
d_indices, features = magface.bulk_inference(
    bbox_features,
    remove_bad_quality_faces=False,
    data_loader_options={'batch_size': 12,
                         'num_workers': 0, 'pin_memory': True},
    tqdm_enabled=True)


## Show face quality using feature magnitudes

In [None]:
def image_grid(array, ncols=10, pad_value=255):
    """Tile a stack of equally sized images into a single grid image.

    Images are laid out row by row in their original order. When the number
    of images is not a multiple of ``ncols``, the last row is completed with
    blank tiles filled with ``pad_value`` instead of failing on the reshape.

    Args:
        array: NumPy array of shape (n_images, height, width, channels).
        ncols: Number of images per grid row.
        pad_value: Fill value of the blank tiles appended to complete the
            last row (255 is white for uint8 images).

    Returns:
        Array of shape (height * nrows, width * ncols, channels) with
        nrows = ceil(n_images / ncols).

    Raises:
        ValueError: If ``array`` is not 4-dimensional.
    """
    if len(array.shape) != 4:
        raise ValueError(
            'expected an array of shape (n_images, height, width, channels), '
            'got shape {}'.format(tuple(array.shape)))
    n_images, height, width, channels = array.shape
    nrows = -(-n_images // ncols)  # ceiling division

    n_missing = nrows * ncols - n_images
    if n_missing:
        # Complete the last row so the reshape below is always valid.
        filler = np.full((n_missing, height, width, channels), pad_value,
                         dtype=array.dtype)
        array = np.concatenate([array, filler], axis=0)

    # (nrows, ncols, h, w, c) -> (nrows, h, ncols, w, c) -> (nrows*h, ncols*w, c)
    img_grid = (array.reshape(nrows, ncols, height, width, channels)
                .swapaxes(1, 2)
                .reshape(height*nrows, width*ncols, channels))

    return img_grid

def display_faces_with_magnitude(features, file_names, bboxes, ncols=10, face_size=112):
    """Plot every face crop in a grid, ordered by ascending feature magnitude.

    The magnitude is the L2 norm of each feature row. Faces are cropped from
    their source image, resized to ``face_size`` x ``face_size`` and tiled
    with ``image_grid``; the sorted magnitudes are printed.

    Args:
        features: 2-D array-like with one feature row per face.
        file_names: Sequence of image paths, one per face (same order as
            ``features``).
        bboxes: Sequence of boxes, one per face. Each item exposes
            ``bounding_box[0]`` and ``bounding_box[1]`` points with ``x`` and
            ``y`` attributes (presumably top-left and bottom-right corners).
        ncols: Number of faces per grid row.
        face_size: Side length in pixels of each square face tile.

    Returns:
        1-D tensor of face indices sorted by ascending feature magnitude.
    """
    # as_tensor avoids torch's copy-construct warning when features is a tensor
    mags = torch.linalg.norm(torch.as_tensor(features), dim=1)
    sort_idx = torch.argsort(mags)

    img_arr = []
    for face_idx in sort_idx.tolist():
        corners = bboxes[face_idx].bounding_box
        box = np.array([corners[0].x, corners[0].y, corners[1].x, corners[1].y])
        # Close the file handle once the crop is taken. Converting to RGB keeps
        # every tile at 3 channels so it stacks with the blank padding tiles.
        with Image.open(file_names[face_idx]) as img:
            face = crop_resize(img.convert('RGB'), box, image_size=face_size)
        img_arr.append(np.asarray(face))

    # Complete the last grid row with white tiles.
    n_missing = -len(img_arr) % ncols
    blank_tile = 255 * np.ones((face_size, face_size, 3), np.uint8)
    img_arr.extend([blank_tile] * n_missing)

    result = image_grid(np.array(img_arr), ncols=ncols)
    plt.figure(figsize=(20., 20.))
    plt.imshow(result)
    print('feature magnitude: {}'.format([round(mag, 2) for mag in mags[sort_idx].tolist()]))
    return sort_idx

In [None]:
# Plot the detected faces 11 per row; the returned sort_idx (face indices ordered
# by feature magnitude) is used by the similarity cell to order rows/columns.
sort_idx = display_faces_with_magnitude(features, file_names, bboxes, ncols=11)

## Face similarity

In [None]:
# Pairwise cosine similarity: rows are L2-normalised, so the matrix product is
# the cosine of the angle between every pair of face features. Rows/columns
# follow sort_idx, i.e. the order of the face grid plotted before.
normalized_features = torch.nn.functional.normalize(torch.as_tensor(features))[sort_idx]
sim_mat = normalized_features @ normalized_features.T
fig, ax = plt.subplots(figsize=(8, 6))
# Draw on the axes created above explicitly rather than on pyplot's current axes.
ax = sns.heatmap(sim_mat, cmap="PuRd", annot=True, ax=ax)
ax.set_title("Pairwise cosine similarity (faces ordered by feature magnitude)");


## Face quality analysis of faces from Office

In [None]:
# Rebind `mtcnn` to a detector configured with image_size=(720, 720) for this
# second image set; `device` comes from the model-loading cell.
mtcnn = MTCNNDetector(device=device, image_size=(720, 720), min_face_size=20).load()  
# run face detection first
base_path = "../tests/fixtures/faces"
# Only files with a 'jpg' extension are requested from the fixture directory.
file_names = list_files_in_dir(base_path, allowed_extensions=('jpg',))
dataset = ImageDataset(file_names)
# Detect faces first
# NOTE: `file_names` is rebound here to the first value returned by
# bulk_inference. The next cell zips it with `outputs`, so the two are
# presumably aligned one-to-one (one list of boxes per file) — TODO confirm.
file_names, outputs = mtcnn.bulk_inference(dataset)


In [None]:
# Flattening all detected faces
# One (index, file name, box) triple is produced per detected face; the index is
# "<file name>_<position of the box within that file's list>".
# NOTE: zip(*...) yields tuples (despite the List annotations below) and the
# unpacking raises ValueError when no face was detected at all.
# NOTE: `file_names` is rebound to the per-face tuple, so re-running this cell
# alone would zip that per-face tuple against `outputs`; re-run the detection
# cell first.
bboxes: List[BBoxOutput]
indices: List[str]
indices, file_names, bboxes = zip(*[
    (f'{fn}_{i}', fn, bbox) for fn, bbox_list in zip(file_names, outputs) for i, bbox in enumerate(bbox_list)
])
# Create a dataset for the bounding boxes
bbox_features = BoundingBoxDataset(indices, file_names, bboxes)

# Get face features
# remove_bad_quality_faces=False presumably keeps every face so that the quality
# ranking below covers all detections — TODO confirm against MagFaceFeatures.
d_indices, features = magface.bulk_inference(
    bbox_features,
    remove_bad_quality_faces=False,
    data_loader_options={'batch_size': 12,
                         'num_workers': 0, 'pin_memory': True},
    tqdm_enabled=True)


In [None]:
# Plot the detected faces 10 per row; the returned sort_idx (face indices ordered
# by feature magnitude) is used by the similarity cell to order rows/columns.
sort_idx = display_faces_with_magnitude(features, file_names, bboxes, ncols=10)

In [None]:
# Pairwise cosine similarity: rows are L2-normalised, so the matrix product is
# the cosine of the angle between every pair of face features. Rows/columns
# follow sort_idx, i.e. the order of the face grid plotted before.
normalized_features = torch.nn.functional.normalize(torch.as_tensor(features))[sort_idx]
sim_mat = normalized_features @ normalized_features.T
fig, ax = plt.subplots(figsize=(20, 20))
# Draw on the axes created above explicitly rather than on pyplot's current axes.
# The colour bar is hidden and annotations use one decimal to keep the large matrix readable.
ax = sns.heatmap(sim_mat, cmap="PuRd", annot=True, cbar=False, fmt=".1f", ax=ax)
ax.set_title("Pairwise cosine similarity (faces ordered by feature magnitude)");
