# Capturing neutral face video

In [1]:
import matplotlib.pyplot as plt
%matplotlib notebook

import os, sys
import yaml
from argparse import ArgumentParser

import imageio
import numpy as np
from skimage.transform import resize
from skimage import img_as_ubyte
import torch

import pdb
#import pyvirtualcam
import time
from calibration.undistort import undistort
import cv2

import face_alignment
# Landmark detector shared by the cells below. Prefer the GPU, but fall back
# to the CPU so the notebook still runs on machines without CUDA.
FA_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device=FA_DEVICE, face_detector='blazeface')

# True: replay './video_headset_on.mp4'; False: grab frames from camera 0.
USE_RECORDED_VIDEO = True
# Number of top rows copied from the neutral video onto the driving frame.
MERGE_HEIGHT = 128
def crop_image(image, margin=80):
    """
    Trim `margin` columns from both the left and the right edge of an image.
    Parameters
    ----------
    image : numpy array
        Array indexed as (rows, columns, channels).
    margin : int, optional
        Number of columns removed from each side. The default is 80.
    Returns
    -------
    numpy array
        Slice of `image` without the outer columns (empty along the column
        axis when the image is narrower than ``2 * margin``).
    Raises
    ------
    ValueError
        If `margin` is negative.
    """
    if margin < 0:
        raise ValueError("margin must be non-negative")
    # Use an explicit end index: `image[:, 0:-0, :]` would be empty for margin=0.
    return image[:, margin:image.shape[1] - margin, :]

# Abort early when the kernel is running on a pre-3 Python interpreter.
if sys.version_info < (3,):
    raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")

def draw_marks(image, marks, color=(0, 255, 0)):
    """
    Draw the facial landmarks on an image (the image is modified in place)
    Parameters
    ----------
    image : np.uint8
        Image on which landmarks are to be drawn.
    marks : list or numpy array
        Facial landmark points; the first two entries of each point are used
        as the x and y pixel coordinates. Float coordinates are truncated to
        integers.
    color : tuple, optional
        Color to which landmarks are to be drawn with. The default is (0, 255, 0).
    Returns
    -------
    None.
    """
    for mark in marks:
        # cv2.circle only accepts integer centre coordinates, so cast here
        # instead of requiring every caller to pre-convert the landmarks.
        center = (int(mark[0]), int(mark[1]))
        cv2.circle(image, center, 2, color, -1, cv2.LINE_AA)

def pre_process_frame(frame, crop = True, cvt_color = True):
    """
    Turn one raw frame into a single-element list holding a 256x256 image.
    Parameters
    ----------
    frame : numpy array
        Input image.
    crop : bool, optional
        When True the frame is first passed through `crop_image`.
    cvt_color : bool, optional
        When True the channels are converted from BGR to RGB.
    Returns
    -------
    list
        One-element list with the resized image, limited to its first three
        channels.
    """
    cropped = crop_image(frame) if crop else frame
    small = cv2.resize(cropped, (256, 256))
    if cvt_color:
        small = cv2.cvtColor(small, cv2.COLOR_BGR2RGB)
    # The skimage resize is kept on purpose: callers later multiply the result
    # by 255, so it presumably also rescales to floats in [0, 1] -- TODO confirm.
    return [resize(small, (256, 256))[..., :3]]

In [2]:
def _read_all_rgb_frames(capture):
    """Drain a cv2.VideoCapture into a list of RGB frames, then release it."""
    frames = []
    try:
        while True:
            ret, frame = capture.read()
            if not ret:
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Free the underlying file handle / decoder even if conversion fails.
        capture.release()
    return frames


source_image = (imageio.imread('./init_img2.png')).astype(np.uint8)
reader2 = cv2.VideoCapture('./init_video.mp4')
driving_video = []
# Defined up front so the summary print below also works in camera mode.
saved_video = []
regular_video = _read_all_rgb_frames(reader2)

## first image of the natural blinking video
source_image = resize(source_image, (256, 256))[..., :3]
## resize the natural blinking video
regular_video = [resize(frame, (256, 256))[..., :3] for frame in regular_video]

if not USE_RECORDED_VIDEO:
    ## make sure camera is working
    # `cap` stays open on purpose: the playback cell keeps reading from it.
    cap = cv2.VideoCapture(0)
    res, frame = cap.read()
    while not res:
        res, frame = cap.read()
        print('retry')
        time.sleep(0.3)
    driving_video = pre_process_frame(frame)
else:
    saved_video_reader = cv2.VideoCapture('./video_headset_on.mp4')
    saved_video = _read_all_rgb_frames(saved_video_reader)
    if not saved_video:
        raise RuntimeError("No frames could be read from './video_headset_on.mp4'")
    # Frames were already converted to RGB while reading, hence cvt_color=False.
    driving_video = pre_process_frame(saved_video[0], crop =False, cvt_color = False)

print(len(regular_video), len(saved_video))

439 602


In [3]:
# Detect landmarks on every neutral frame and accumulate the face boxes so an
# average face location can be computed afterwards.
landmark_map = []
face_loc = np.zeros(5)
for i, neutral_frame in enumerate(regular_video):
    frame = (neutral_frame * 255)
    preds, _, faces = fa.get_landmarks(frame, return_bboxes=True)
    if preds:
        # Store the landmarks together with a fresh [0, 1]-scaled copy of the frame.
        landmark_map.append([preds[0], frame / 255])
        face_loc += faces[0]
    else:
        print("No face", i)

if not landmark_map:
    # Dividing by zero below would only produce NaNs and fail later, far from the cause.
    raise RuntimeError("No face was detected in any frame of the neutral video")

# Keep the first four entries (the fifth accumulated value is dropped) and average them.
face_loc = face_loc[:4] / len(landmark_map)
print(faces)
print(face_loc)
if(preds):
    # Jawline, and mouth
    # NOTE(review): in the common 0-indexed 68-point layout the outer lip is
    # 48:60; 49:61 may be an off-by-one -- confirm (landmark_diff uses the same slice).
    last_landmarks = preds[0].astype(int)
    headset_features = np.concatenate((last_landmarks[:17], last_landmarks[49:61]))
    draw_marks(frame,headset_features,color = (255,0,0))

plt.imshow(frame/255)

[[ 81.79329681  65.30234528 193.44833374 176.95756531   0.91637206]]
[ 78.78149136  62.49193937 190.78902548 174.50008773]


<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7f2484131df0>

In [6]:
def landmark_diff(la, lb, normalize=False):
    """
    Sum of squared differences between the mouth landmarks of two faces.
    Parameters
    ----------
    la, lb : numpy array
        Landmark arrays with one point per row; rows 49..60 are compared.
    normalize : bool, optional
        When True each mouth is centred on its own centroid before comparing,
        which makes the score independent of where the face sits in the image.
        The default is False, i.e. raw coordinates are compared (the behaviour
        this function has always had).
    Returns
    -------
    float
        Sum over all compared points and coordinates of the squared difference.
    """
    mouth_a = la[49:61]
    mouth_b = lb[49:61]
    if normalize:
        # The centroid is the mean over the points (axis 0). Averaging over
        # axis 1 would mix the x and y coordinates of each individual point.
        mouth_a = mouth_a - np.mean(mouth_a, axis=0, keepdims=True)
        mouth_b = mouth_b - np.mean(mouth_b, axis=0, keepdims=True)
    return np.sum((mouth_a - mouth_b) ** 2)

def best_next_frame(current, captured):
    """
    Pick the stored neutral frame whose mouth landmarks best match a captured frame.
    Parameters
    ----------
    current : int
        Index into `landmark_map` of the frame currently shown. It is kept
        unless another entry scores strictly lower.
    captured : numpy array
        Captured image on which landmarks are detected, using the averaged
        face box `face_loc` as the detection.
    Returns
    -------
    tuple
        ``(best_index, la, lb)`` where `la` are the landmarks stored for
        `current` and `lb` are the landmarks found on `captured`. When no
        landmarks are found the result is ``(current, la, None)``, so callers
        can always unpack three values.
    """
    la = landmark_map[current][0]
    preds = fa.get_landmarks(captured, detected_faces=[[*face_loc, 1]])
    if preds is None or len(preds) == 0:
        return current, la, None
    lb = preds[0]

    best_i = current
    best_score = landmark_diff(la, lb)
    for i, entry in enumerate(landmark_map):
        score = landmark_diff(entry[0], lb)
        # Strict comparison: on ties the current frame (or the earliest
        # better candidate) wins.
        if score < best_score:
            best_score = score
            best_i = i
    return best_i, la, lb

In [12]:
## Add the neutral blinking video to the top part
driving_video[0][:MERGE_HEIGHT,:] = regular_video[0][:MERGE_HEIGHT,:]
# cv2.imshow(winname="Face", mat=cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR))


counter = 0
cur_frame = 0
#with pyvirtualcam.Camera(width=256, height=256, fps=30,device='/dev/video5') as cam:
try:
    while(1):
        start = time.time()
        if not USE_RECORDED_VIDEO:
            res, frame = cap.read()
            if not res:
                print('no frame')
                continue
            cv2.imwrite("normalframe.jpg", frame)
            frame = pre_process_frame(frame)
        else:
            frame = saved_video[counter%len(saved_video)]
            frame = pre_process_frame(frame, crop = False, cvt_color = False)
            # `counter` indexes saved_video, so it must wrap at that length;
            # wrapping at len(regular_video) would skip the tail of a longer recording.
            counter = (counter+1)%len(saved_video)
        # frame[0] is RGB here while cv2.imshow expects BGR, so flip the channels.
        cv2.imshow("input frame", np.ascontiguousarray(frame[0][..., ::-1]))

        ## Replace the frame with the stored neutral frame whose mouth matches best
        next_frame, la, lb = best_next_frame(cur_frame, frame[0] * 255)
        cur_frame = next_frame
        frame[0] = landmark_map[next_frame][1].copy()

        #f1 = np.concatenate((la.astype(int)[:17],la.astype(int)[49:61]))
        #draw_marks(frame[0],f1,color = (255,0,0))
        #f2 = np.concatenate((lb.astype(int)[:17],lb.astype(int)[49:61]))
        #draw_marks(frame[0],f2,color = (0,0,255))

        ## debugging visualization
        single_frame = frame[0]
        tmp = cv2.resize(single_frame,(720,720))
        # Swap the first and third channel (RGB -> BGR) for display.
        tmp[:,:,[0,2]] = tmp[:,:,[2,0]]
        cv2.imshow("split frame",tmp)
        ch = cv2.waitKey(1)
        if ch == ord('q'):
            break
finally:
    # Also runs on a kernel interrupt, so no preview window is left hanging.
    cv2.destroyAllWindows()
