In [None]:
# !pip install facenet-pytorch

In [1]:
from facenet_pytorch import MTCNN
import torch
import numpy as np
from glob import glob
import PIL
from PIL import Image
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Use the first CUDA GPU when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print('Running on device: {}'.format(device))

Running on device: cuda:0


In [3]:
# Face detector living on the selected device.
# NOTE(review): these thresholds look deliberately permissive (presumably so
# hard-to-see faces still produce a box) -- confirm they are intended.
mtcnn = MTCNN(
    keep_all=True,
    thresholds=[0.1, 0.1, 0.2],
    device=device,
)

In [4]:
# Collect the image paths. glob() returns matches in an arbitrary,
# filesystem-dependent order, so sort them to keep batch composition and the
# order of any logged errors identical from run to run.
# Training images sit one directory level below images/; eval images are flat.
img_list = sorted(glob('../input/data/train/images/*/*.jpg'))
eval_list = sorted(glob('../input/data/eval/images/*.jpg'))

In [5]:
def expand_bbox(bbox, ratio=0.2):
    """Grow a box outward by ``ratio`` of its own width and height.

    Args:
        bbox: Four values ``(x_min, y_min, x_max, y_max)``.
        ratio: Total fractional growth per axis; half of it is applied on
            each side of the box.

    Returns:
        Tuple ``(x_min, y_min, x_max, y_max)`` of ints. The two minimums are
        clamped at 0; the two maximums are not clamped to any upper bound.
    """
    left, top, right, bottom = bbox

    # Padding added on each side: half of the requested growth per axis.
    pad_x = (right - left) * ratio / 2
    pad_y = (bottom - top) * ratio / 2

    grown = (
        max(0, int(left - pad_x)),
        max(0, int(top - pad_y)),
        int(right + pad_x),
        int(bottom + pad_y),
    )
    return grown

In [6]:
def crop_bbox(img, bbox):
    """Return a copy of the region of ``img`` covered by ``bbox``.

    Args:
        img: Image as a NumPy array indexed ``[row, column, ...]``, or any
            object ``np.asarray`` can convert (for example a PIL image of
            any format, not only JPEG).
        bbox: Sequence whose first four items are
            ``(x_min, y_min, x_max, y_max)`` integer bounds; the max bounds
            are exclusive, as in ordinary slicing.

    Returns:
        A new ``np.ndarray`` holding rows ``y_min:y_max`` and columns
        ``x_min:x_max``. It does not share memory with ``img``.
    """
    # Decide by "is it already an array?" instead of testing for one specific
    # PIL class, so non-JPEG images and images produced by PIL operations are
    # converted too rather than failing on the 2-D slice below.
    if not isinstance(img, np.ndarray):
        img = np.asarray(img)
    return img[bbox[1]:bbox[3], bbox[0]:bbox[2]].copy()

In [7]:
import os
import cv2

In [8]:
from tqdm import tqdm

In [9]:
# Number of image paths sliced off the path list, opened, and handed to
# mtcnn.detect() together in each iteration of the batch-inference loops.
batch_size = 128

In [None]:
# Batch inference over the training images: detect faces, crop the first
# returned box (grown by expand_bbox's default margin) and write the crop to
# the same relative path under a 'face_input' root.
for i in tqdm(range(0, len(img_list), batch_size)):
    batch_img_list = img_list[i : batch_size+i]
    batch_inputs = [Image.open(path) for path in batch_img_list]

    # One entry per input image. Per the facenet-pytorch docs, an entry is
    # None when no face was found in that image.
    boxes, _ = mtcnn.detect(batch_inputs)

    for img_path, img, bbox in zip(batch_img_list, batch_inputs, boxes):
        # Best effort: one bad image is reported and skipped, not fatal.
        try:
            # Report a missing detection explicitly instead of relying on
            # the TypeError raised by indexing None.
            if bbox is None or len(bbox) == 0:
                print("Error:", img_path, "no face detected")
                continue

            bbox = expand_bbox(bbox[0])
            crop_face = crop_bbox(img, bbox)

            # Replace only the first 'input' (the data root) so a later
            # occurrence in a folder or file name is left untouched.
            save_path = img_path.replace('input', 'face_input', 1)
            os.makedirs(os.path.dirname(save_path), exist_ok=True)

            # cv2.imwrite signals failure through its return value rather
            # than an exception, so check it to avoid silently losing files.
            if not cv2.imwrite(save_path, crop_face[:,:,::-1]): # RGB -> BGR
                print("Error:", img_path, "could not write " + save_path)
        except Exception as e:
            print("Error:", img_path, e)

In [None]:
# Batch inference over the evaluation images: detect faces, crop the first
# returned box (grown by expand_bbox's default margin) and write the crop to
# the same relative path under a 'face_input' root.
for i in tqdm(range(0, len(eval_list), batch_size)):
    batch_img_list = eval_list[i : batch_size+i]
    batch_inputs = [Image.open(path) for path in batch_img_list]

    # One entry per input image. Per the facenet-pytorch docs, an entry is
    # None when no face was found in that image.
    boxes, _ = mtcnn.detect(batch_inputs)

    for img_path, img, bbox in zip(batch_img_list, batch_inputs, boxes):
        # Best effort: one bad image is reported and skipped, not fatal.
        try:
            # Report a missing detection explicitly instead of relying on
            # the TypeError raised by indexing None.
            if bbox is None or len(bbox) == 0:
                print("Error:", img_path, "no face detected")
                continue

            bbox = expand_bbox(bbox[0])
            crop_face = crop_bbox(img, bbox)

            # Replace only the first 'input' (the data root) so a later
            # occurrence in a folder or file name is left untouched.
            save_path = img_path.replace('input', 'face_input', 1)
            os.makedirs(os.path.dirname(save_path), exist_ok=True)

            # cv2.imwrite signals failure through its return value rather
            # than an exception, so check it to avoid silently losing files.
            if not cv2.imwrite(save_path, crop_face[:,:,::-1]): # RGB -> BGR
                print("Error:", img_path, "could not write " + save_path)
        except Exception as e:
            print("Error:", img_path, e)

 99%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 98/99 [18:31<00:11, 11.98s/it]