# Face alignment using a state-of-the-art model

[Papers With Code](https://paperswithcode.com/about) is an awesome resource for Machine Learning information, covering both papers and code.

For this experiment I chose the [leading method (as of 22.12.2020)](https://paperswithcode.com/paper/towards-fast-accurate-and-stable-3d-dense-1) on the face alignment task.

It is conveniently available as a [PyTorch model on GitHub](https://github.com/cleardusk/3DDFA_V2).

````bash
# that repository is a prerequisite here
git clone https://github.com/cleardusk/3DDFA_V2.git
cd 3DDFA_V2
./build.sh
pip3 install -r requirements.txt
````


In [None]:
import os
import cv2
import PIL.Image
import torch
import numpy as np
from jetutils import SimpleTimer
from IPython.display import display


import cv2
import yaml

import sys
sys.path.append('3DDFA_V2')
from FaceBoxes.FaceBoxes_ONNX import FaceBoxes_ONNX

from TDDFA import TDDFA
from utils.functions import draw_landmarks, cv_draw_landmark
from utils.render import render
from utils.depth import depth
from utils.pncc import pncc
from utils.pose import viz_pose

from torch2trt import TRTModule

from jetutils import SimpleTimer
from IPython.display import display
from sidecar import Sidecar
from torch2trt import TRTModule
import torchvision.transforms as transforms

With the original ONNX runtime model, processing time is ~45 ms, face detection time is ~100 ms,
and rendering time is ~50 ms per image.

When the model is converted to TensorRT, its processing time drops to ~25 ms per image. Rendering and face detection still take too much time, so the best frame rate is about 5 frames/s.

In [None]:
# Runtime environment settings for this notebook process.
# NOTE(review): these are assigned after torch/numpy/cv2 were imported in the
# previous cell; libraries that read them at load time may not see them -
# confirm they actually take effect.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
os.environ['OMP_NUM_THREADS'] = '4'

# Cached TensorRT engines (state dicts) and the model input size used when
# building the dummy tensor for the conversion.
MODEL_PATH = 'models/mb1_120x120_trt.pth'
FACE_BOX_PATH = 'models/facebox_trt.pth'
MODEL_H = 120
MODEL_W = 120

# Load the 3DDFA_V2 model configuration; the context manager closes the file
# handle instead of leaving it open until garbage collection.
with open('3DDFA_V2/configs/mb1_120x120.yml') as config_file:
    cfg = yaml.load(config_file, Loader=yaml.SafeLoader)

#patch config paths
def patch_path(config, key):
    """Rewrite ``config[key]`` in place so it is prefixed with ``3DDFA_V2/``.

    The stored value is formatted into the prefixed path and written back
    under the same key; nothing is returned.
    """
    repo_relative = config[key]
    prefixed = '3DDFA_V2/{}'.format(repo_relative)
    config[key] = prefixed

# The config stores these two paths relative to the 3DDFA_V2 checkout, so
# prefix them before handing the config to TDDFA.
for key in ('checkpoint_fp', 'bfm_fp'):
    patch_path(cfg, key)

tddfa = TDDFA(gpu_mode=True, **cfg)
face_boxes = FaceBoxes_ONNX()

# Either build the TensorRT version of the regression model and cache its
# state dict at MODEL_PATH, or restore a previously cached one.
timer = SimpleTimer()
if not os.path.exists(MODEL_PATH):
    import torch2trt

    # Create the target directory up front so a missing folder cannot make
    # torch.save() fail after the conversion has already been paid for.
    model_dir = os.path.dirname(MODEL_PATH)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    with timer:
        trt_model = torch2trt.torch2trt(tddfa.model,
                                        [torch.zeros((1, 3, MODEL_H, MODEL_W)).cuda()],
                                        fp16_mode=True,
                                        max_workspace_size=1<<20)
    print('model converted to TensorRT in {} seconds'.format(timer.time))
    torch.save(trt_model.state_dict(), MODEL_PATH)
else:
    with timer:
        trt_model = TRTModule()
        trt_model.load_state_dict(torch.load(MODEL_PATH))
    print('model loaded in {} seconds'.format(timer.time))

# From here on TDDFA runs the TensorRT module instead of its original model.
tddfa.model = trt_model

# TODO: figure out the proper conversion for the face box detection model
#if not os.path.exists(FACE_BOX_PATH):
#    import torch2trt
#    timer = SimpleTimer()
#    with timer:
#        trt_model_fb = torch2trt.torch2trt(face_boxes.net.cuda(),
#                                          [torch.zeros((1, 3, MODEL_H, MODEL_W)).cuda()],
#                                          fp16_mode=True,
#                                          max_workspace_size=1<<20)
#    print('face box converted to TensorRT in {} seconds'.format(timer.time))
#    torch.save(trt_model_fb.state_dict(), FACE_BOX_PATH)        
#    face_boxes.net = trt_model_fb
#else:
#    timer = SimpleTimer()
#    with timer:
#        trt_model_fb = TRTModule()
#        trt_model_fb.load_state_dict(torch.load(FACE_BOX_PATH))
#    print('face box loaded in {} seconds'.format(timer.time))
#    face_boxes.net = trt_model_fb



In [None]:
import traitlets
import io
from tqdm.notebook import tqdm

class TDDFAProcess(traitlets.HasTraits):
    """Run face detection, 3DDFA regression and rendering on single frames.

    The rendering options are traitlets so they can be linked to notebook
    widgets, and the rendered image is published through ``output_frame``.
    """

    # Declared but not read or written by this class.
    input_frame = traitlets.Any()
    # Most recently rendered image; updated at the end of process_frame().
    output_frame = traitlets.Any()
    # Passed as ``with_bg_flag`` to the renderers; for 'pose' it selects
    # between a copy of the input and a black canvas.
    draw_original = traitlets.Bool(default_value=False)
    # One of 'landmarks', 'mask', 'depth', 'pncc', 'pose'; any other value
    # leaves the frame unrendered.
    render_mode = traitlets.Unicode(default_value='mask')

    # Only honoured in 'landmarks' mode; all other modes reconstruct dense
    # vertices regardless of this flag.
    dense = traitlets.Bool(default_value=True)
    # Mask opacity in percent; divided by 100 before being passed to render().
    alpha = traitlets.Integer(default_value=60)

    def __init__(self, model, detector):
        """Store the regression model and the face detector.

        Args:
            model: callable as ``model(img, boxes)`` that also provides
                ``recon_vers`` and ``tri`` (the notebook passes a TDDFA
                instance).
            detector: callable as ``detector(img)`` returning face boxes.
        """
        # HasTraits subclasses should run the base initialiser so the trait
        # machinery is set up the documented way.
        super().__init__()

        self._model = model
        self._detector = detector
        # Not used inside this class; the notebook attaches its debug
        # Textarea widget here.
        self.debug_out = None
        # Black placeholder until the first frame has been processed.
        # NOTE(review): numpy shape is (rows, cols, channels), i.e. 640 high
        # by 480 wide - confirm this orientation is intended.
        self.output_frame = np.zeros((640, 480, 3), dtype=np.uint8)

    def process_frame(self, img):
        """Detect faces in ``img``, render them and publish the result.

        Args:
            img: image array; it is converted to ``uint8`` before use.

        Returns:
            A multi-line string with the number of detected boxes, the time
            spent in each stage, the overall fps and, in 'pose' mode, the
            yaw/pitch/roll values.
        """
        img = img.astype(np.uint8)
        # ``timer`` is reused for every stage; ``fpst`` spans the whole run.
        timer = SimpleTimer()
        fpst = SimpleTimer()
        with fpst:
            with timer:
                boxes = self._detector(img)
            boxtime = timer.time
            with timer:
                param_lst, roi_box_lst = self._model(img, boxes)
            regtime = timer.time
            with timer:
                # Only landmark drawing can work with sparse points; every
                # other mode is given dense vertices.
                dense = self.dense if self.render_mode == 'landmarks' else True
                ver_lst = self._model.recon_vers(param_lst, roi_box_lst,
                                                 dense_flag=dense)
            recontime = timer.time
            pose_info = ''
            with timer:
                if self.render_mode == 'landmarks':
                    # landmark plotting is way too slow for real time...
                    # NOTE(review): assumes draw_landmarks returns the drawn
                    # image - confirm against the 3DDFA_V2 checkout in use.
                    img = draw_landmarks(img, ver_lst, show_flag=False,
                                         dense_flag=self.dense)
                    # Swap the channel order of the returned image.
                    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

                elif self.render_mode == 'mask':
                    img = render(img, ver_lst, self._model.tri,
                                 alpha=self.alpha / 100.0, show_flag=False,
                                 with_bg_flag=self.draw_original)

                elif self.render_mode == 'depth':
                    img = depth(img, ver_lst, self._model.tri, show_flag=False,
                                with_bg_flag=self.draw_original)
                elif self.render_mode == 'pncc':
                    img = pncc(img, ver_lst, self._model.tri, show_flag=False,
                               with_bg_flag=self.draw_original)
                elif self.render_mode == 'pose':
                    # Draw either on top of the frame or on a black canvas.
                    if self.draw_original:
                        img2 = img.copy()
                    else:
                        img2 = np.zeros_like(img)
                    # NOTE(review): assumes viz_pose returns (image, pose)
                    # with pose indexable as yaw, pitch, roll - confirm.
                    img, pose = viz_pose(img2, param_lst, ver_lst, show_flag=False)
                    pose_info = 'yaw   {}\npitch {}\nroll  {}'.format(pose[0], pose[1], pose[2])
            rendertime = timer.time
        fps = fpst.fps

        # Assigning the trait notifies anything linked to ``output_frame``.
        self.output_frame = img
        return 'faces {}\ndetect time {}\nmodel time  {}\nrecon time  {}\nrender time {}\nfps {}\n\n{}'.format(
                len(boxes), boxtime, regtime, recontime, rendertime, fps, pose_info)

    @classmethod
    def validate(cls, model, detector, img_path):
        """Run every render mode once on the image at ``img_path``.

        Each mode's statistics are printed and its rendered frame is
        displayed inline.

        Raises:
            FileNotFoundError: if ``cv2.imread`` cannot read ``img_path``.
        """
        # cv2.imread signals failure by returning None rather than raising;
        # report that clearly instead of failing later on ``None.astype``.
        data = cv2.imread(img_path)
        if data is None:
            raise FileNotFoundError('could not read image: {}'.format(img_path))

        instance = cls(model, detector)
        # validate all rendering modes
        instance.draw_original = True

        for mode in tqdm(('landmarks', 'mask', 'depth', 'pncc', 'pose')):
            instance.render_mode = mode
            # process_frame() copies its input via astype(), so the same
            # decoded image can be reused for every mode.
            print(instance.process_frame(data))
            display(PIL.Image.fromarray(cv2.cvtColor(instance.output_frame, cv2.COLOR_BGR2RGB)))

# Exercise every render mode once on a sample image before going live.
TDDFAProcess.validate(
    model=tddfa,
    detector=face_boxes,
    img_path='jetson-inference/data/images/humans_6.jpg',
)
# Processor instance used by the widget and camera cells that follow.
face_model = TDDFAProcess(model=tddfa, detector=face_boxes)

In [None]:
from jetutils import GstCamera, bgr8_to_jpeg
# Frame source for the rest of the notebook: later cells use its
# width/height to size the widgets and call camera.read() for frames.
camera = GstCamera()

My standard display and control box, with controls for the different modes and values:

In [None]:
import ipywidgets
# Two camera-sized image widgets, each starting out as an all-black JPEG.
image_original = ipywidgets.Image(
    value=bgr8_to_jpeg(np.zeros((camera.height, camera.width, 3), dtype=np.uint8)),
    format='jpeg',
    width=camera.width,
    height=camera.height,
)
image_processed = ipywidgets.Image(
    value=bgr8_to_jpeg(np.zeros((camera.height, camera.width, 3), dtype=np.uint8)),
    format='jpeg',
    width=camera.width,
    height=camera.height,
)
images_out = ipywidgets.HBox(children=[image_original, image_processed])

# Read-only text area that receives the per-frame statistics string.
debug_layout = ipywidgets.Layout(width='640px', height='520px')
debug_out = ipywidgets.Textarea(value='', disabled=True, layout=debug_layout)
face_model.debug_out = debug_out

# Controls: render mode selector, background toggle, dense toggle and
# mask opacity slider.
select_outmode = ipywidgets.ToggleButtons(
    options=['landmarks', 'mask', 'depth', 'pncc', 'pose'],
    value='mask',
    description='render',
    style={'description_width': 'initial'},
)
select_original = ipywidgets.ToggleButton(value=False, description='over orginal')
select_dense = ipywidgets.ToggleButton(value=True, description='dense')
alpha_slider = ipywidgets.IntSlider(
    value=50,
    min=0,
    max=100,
    description='mask alpha %',
    style={'description_width': 'initial'},
)

control_box = ipywidgets.HBox(children=[
    select_outmode,
    ipywidgets.Label(value=' '),
    select_original,
    select_dense,
    alpha_slider,
])

# Stack images, controls and debug text, and show them in a sidecar panel.
all_box = ipywidgets.VBox(children=[images_out, control_box, debug_out])
_sidecar = Sidecar(title='output')
with _sidecar:
    display(all_box)

In [None]:
import IPython
ipython = IPython.get_ipython()
# One-way links into the two image widgets; bgr8_to_jpeg converts each frame
# before it is assigned to the widget's value.
traitlets.dlink((camera, 'value'), (image_original, 'value'), transform=bgr8_to_jpeg)
traitlets.dlink((face_model,'output_frame'), (image_processed, 'value'), transform=bgr8_to_jpeg)

# One-way links from the control widgets into the processor's option traits.
traitlets.dlink((select_original, 'value'), (face_model, 'draw_original'))
traitlets.dlink((select_outmode, 'value'), (face_model, 'render_mode'))
traitlets.dlink((select_dense, 'value'), (face_model, 'dense'))
traitlets.dlink((alpha_slider, 'value'), (face_model, 'alpha'))

# this model is too slow to be run each time camera streams new image
# so get new image only after processing old one
# The loop has no exit condition; it runs until the cell is interrupted.
while True: 
    # Presumably lets the kernel handle pending messages (e.g. widget
    # changes) while this cell keeps running - confirm against the installed
    # ipykernel version.
    ipython.kernel.do_one_iteration()
    frame = camera.read()
    # process_frame() publishes the rendered image via output_frame and
    # returns the statistics text shown in the debug area.
    debug_out.value = face_model.process_frame(frame)
    
#camera.observe(process, names='value')
#camera.running = True

In [None]:
# Remove every observer registered on the camera, then clear its running
# flag (presumably this stops frame capture - confirm in jetutils).
camera.unobserve_all()
camera.running = False

In [None]:
# Drop the notebook's reference to the camera object (presumably so the
# underlying device can be released - confirm in jetutils).
del camera