In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import cv2
import numpy as np
from glob import glob 
from tqdm import tqdm
from skimage import transform as trans

import sys
sys.path.append('./RetinaFace')
from retinaface import RetinaFace

import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="1"

### Collect input images

In [None]:
IMAGE_DIR = '/mnt/hdd2/David/Dataset/bald_for_gan'
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')


def collect_images(directory, extensions=IMAGE_EXTENSIONS):
    """Return the sorted paths of the image files directly inside `directory`.

    Parameters
    ----------
    directory : str
        Folder to scan (not recursive; hidden dot-files are skipped, as with glob).
    extensions : iterable of str
        Accepted file suffixes including the dot. Matching is case-insensitive,
        so ``photo.JPG`` is picked up by ``'.jpg'``.

    Returns
    -------
    list of str
        Matching paths in sorted order. ``glob`` returns files in arbitrary
        filesystem order, and the output crops are numbered by position in this
        list, so sorting makes that numbering reproducible.
    """
    wanted = {ext.lower() for ext in extensions}
    # One '*' glob plus a suffix filter (rather than one glob per extension and
    # case variant) cannot return the same file twice on case-insensitive filesystems.
    return sorted(
        path for path in glob(os.path.join(directory, '*'))
        if os.path.splitext(path)[1].lower() in wanted
    )


images = collect_images(IMAGE_DIR)
print('Images:', len(images))

In [None]:
# Eyeball one input image before running the whole pipeline.
SAMPLE_INDEX = 500  # arbitrary choice; clamped below so small datasets still work

if images:
    sample_path = images[min(SAMPLE_INDEX, len(images) - 1)]
    img = cv2.imread(sample_path)
    # cv2.imread signals failure by returning None rather than raising.
    if img is None:
        raise IOError('Could not read image: ' + sample_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; matplotlib expects RGB
    fig, ax = plt.subplots()
    ax.imshow(img)
    ax.set_title(os.path.basename(sample_path))
    ax.axis('off')
    plt.show()
else:
    print('No images to preview.')

### Load detection model

In [None]:
# RetinaFace-R50 weights, given as a checkpoint path prefix plus an epoch number
# (presumably an MXNet-style "<prefix>-symbol.json" / "<prefix>-0000.params" pair — confirm).
retinaface_prefix = './RetinaFace/retinaface-R50/R50'
retinaface_epoch = 0
# Device index *within* the GPUs left visible by CUDA_VISIBLE_DEVICES, so 0
# selects the single card exposed in the first cell.
gpuid = 0
# 'net3' is passed through as RetinaFace's network-variant argument.
detector = RetinaFace(retinaface_prefix, retinaface_epoch, gpuid, 'net3')

In [None]:
# Alignment configuration shared by align_face() below.
# NOTE(review): `minsize` and `factor` look like MTCNN leftovers — no visible cell
# reads them (detector.detect only receives `thresh` and `scales`). Confirm before removing.
minsize = 100 # minimum size of face (unused in the visible cells)
thresh = 0.8  # detection score threshold passed to detector.detect
scales = [1.0]  # image scales passed to detector.detect (1.0 = original resolution only)
factor = 0.709 # scale factor (unused in the visible cells)
image_size = [256,256]  # output crop as [height, width]; align_face passes (width, height) to cv2.warpAffine
# Five-point landmark template that detected landmarks are mapped onto.
# These values match the widely used InsightFace/ArcFace 5-point reference
# (presumably left eye, right eye, nose tip, left and right mouth corners —
# confirm the ordering matches RetinaFace's landmark output).
src = np.array([
  [30.2946, 51.6963],
  [65.5318, 51.5014],
  [48.0252, 71.7366],
  [33.5493, 92.3655],
  [62.7299, 92.2041] ], dtype=np.float32 )

# Reposition and rescale the template. These are deliberately separate in-place
# float32 ops: merging them (e.g. one `+= 23.0`, or `*= 200/112`) could change
# the rounding of the stored coordinates.
src[:,0] += 8.0   # horizontal shift (commonly used to centre the 96-wide reference in a 112x112 frame)
src[:,0] += 15.0  # additional horizontal offset
src[:,1] += 30.0  # push landmarks down, leaving space above the face (presumably for the head/scalp)
src /= 112        # normalise by the 112-px reference frame ...
src *= 200        # ... and rescale; landmarks end up at roughly x 95-158, y 146-219 in the 256x256 crop

In [None]:
def align_face(img, template=None, output_size=None):
    """Detect faces in `img` and warp each one onto the landmark template.

    Parameters
    ----------
    img : np.ndarray
        Image exactly as loaded by ``cv2.imread`` (BGR); passed unchanged to
        ``detector.detect`` and ``cv2.warpAffine``.
    template : np.ndarray, optional
        Target landmark coordinates in the output crop, one row per detected
        landmark (the 5x2 module-level ``src`` by default).
    output_size : sequence of two ints, optional
        Output crop size as ``[height, width]`` (module-level ``image_size``
        by default).

    Returns
    -------
    list of np.ndarray
        One aligned crop per detected face, in the detector's output order
        (presumably highest score first — confirm against RetinaFace). Faces
        whose similarity transform cannot be estimated are skipped; the list
        is empty when nothing passes ``thresh``.
    """
    if template is None:
        template = src
    if output_size is None:
        output_size = image_size
    # cv2.warpAffine takes dsize as (width, height).
    dsize = (int(output_size[1]), int(output_size[0]))

    bounding_boxes, points = detector.detect(img, thresh, scales=scales, do_flip=True)

    result_faces = []
    if bounding_boxes.shape[0] == 0:
        return result_faces

    # Only the landmarks drive the alignment; the boxes just give the face count.
    for i in range(bounding_boxes.shape[0]):
        landmarks = points[i]
        tform = trans.SimilarityTransform()
        tform.estimate(landmarks, template)
        M = tform.params[0:2, :]
        # Degenerate landmarks (e.g. all collinear) make the estimate fail with
        # NaN parameters; warping with them would yield a garbage crop.
        if not np.all(np.isfinite(M)):
            continue
        warped = cv2.warpAffine(img, M, dsize, borderValue=0.0)
        result_faces.append(warped)

    return result_faces

### Detect, align and save face crops

In [None]:
# Align the first detected face of every input image and save it as <num>.jpg.
OUTPUT_DIR = '/mnt/ssd2/Datasets/bald_for_GAN_test/'
# cv2.imwrite returns False (it does not raise) when the directory is missing.
os.makedirs(OUTPUT_DIR, exist_ok=True)

num = 0        # number of crops written; also the next output file name
skipped = []   # (input path, reason) for every image that produced no crop
for filename in tqdm(images):
    img = cv2.imread(filename)
    if img is None:  # unreadable / corrupt file
        skipped.append((filename, 'unreadable'))
        continue
    try:
        faces = align_face(img)
    except Exception as exc:
        # Best effort: one bad image should not abort the whole run, but the
        # reason is recorded instead of being silently discarded. Catching
        # Exception (not a bare except) keeps Ctrl-C working.
        skipped.append((filename, repr(exc)))
        continue
    if not faces:
        skipped.append((filename, 'no face'))
        continue
    dst_filename = os.path.join(OUTPUT_DIR, str(num) + '.jpg')
    if not cv2.imwrite(dst_filename, faces[0]):
        skipped.append((filename, 'write failed'))
        continue
    num += 1  # only count successful writes so output names stay contiguous

print('Saved:', num, 'Skipped:', len(skipped))