In [None]:
# Change the working directory to the parent folder (presumably the repository root,
# so that `env` and `face_reconstruction` can be imported — confirm against the layout).
# NOTE: the path is relative, so re-running this cell moves up one more level each time.
%cd ..
# Load (or reload) IPython's autoreload extension and let it re-import modified modules
# before each cell execution.
%reload_ext autoreload
%autoreload 2

In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pyrender
from scipy import optimize

from env import DATA_PATH
from face_reconstruction.graphics import SimpleImageRenderer, draw_pixels_to_image, cv2_to_plt
from face_reconstruction.landmarks import load_bfm_landmarks, detect_landmarks
from face_reconstruction.model import BaselFaceModel
from face_reconstruction.optim import SparseOptimization
from face_reconstruction.plots import PlotManager

# 1. Face Model

In [None]:
# Load the Basel Face Model together with its landmark table.
bfm = BaselFaceModel.from_h5("model2019_face12.h5")
bfm_landmarks = load_bfm_landmarks("model2019_face12_landmarks_v2")

# Only the mapped values are needed further down, where they index the mesh vertices.
bfm_landmark_indices = [*bfm_landmarks.values()]

In [None]:
# Sizes of the model's coefficient vectors; used below to build all-zero coefficient lists.
n_shape_coefficients, n_expression_coefficients, n_color_coefficients = (
    bfm.get_n_shape_coefficients(),
    bfm.get_n_expression_coefficients(),
    bfm.get_n_color_coefficients(),
)

# 2. Input image

In [None]:
# File name of the input photo. It is looked up in DATA_PATH's "Keypoint Detection"
# folder and is also embedded in the names of the plots saved further down.
img_name = "trump.jpg"

In [None]:
# Build the path of the input photo and load it with OpenCV.
img_path = f"{DATA_PATH}/Keypoint Detection/{img_name}"
img = cv2.imread(img_path)

# cv2.imread does not raise for a missing or unreadable file; it returns None instead.
# Fail right here with a clear message rather than later with an obscure error.
if img is None:
    raise FileNotFoundError(f"Could not read image: {img_path}")

In [None]:
# Detect facial landmarks in the input image; they are passed to the optimization losses
# below and drawn onto the image in section 5.
# NOTE(review): presumably one pixel coordinate per landmark, ordered consistently with
# bfm_landmark_indices — confirm against detect_landmarks.
landmarks_img = detect_landmarks(img)

In [None]:
# The first two entries of the image shape are (height, width), in that order.
img_height, img_width = img.shape[:2]

# 3. Setup rendering pipeline

In [None]:
# Perspective camera with a vertical field of view of pi/3 (60 degrees); its projection
# matrix is computed for the dimensions of the input image.
perspective_camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0)
projection_matrix = perspective_camera.get_projection_matrix(
    width=img_width,
    height=img_height,
)

# Starting pose: identity rotation and a translation of -300 along z
# (positions the camera just in front of the face).
initial_camera_pose = np.array(
    [
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, -300],
        [0, 0, 0, 1],
    ]
)

In [None]:
# Renderer built from the projection matrix and image size; used below by the losses and
# via renderer.project_points to map 3D points to image pixels.
renderer = SimpleImageRenderer(projection_matrix, img_width, img_height)

# 4. Optimization

## 4.1. Joint optimization for face parameters and pose

### 4.1.1. Setup Joint Optimization

In [None]:
# Hyperparameters handed to SparseOptimization below.
# NOTE(review): the n_params_* values presumably select how many model coefficients are
# optimized, and the weight_* values presumably scale a regularization term on them —
# confirm against SparseOptimization.
n_params_shape = 20 # 20
n_params_expression = 10 # 10
weight_shape_params = 10000 # 10000
weight_expression_params = 1000 # 1000

In [None]:
# Optimizer for the joint fit of face parameters and camera pose.
sparse_optimization = SparseOptimization(
    bfm,
    n_params_shape,
    n_params_expression,
    weight_shape_params=weight_shape_params,
    weight_expression_params=weight_expression_params,
)

# Loss callable that is handed to scipy's least-squares solver below.
loss = sparse_optimization.create_loss(renderer, bfm_landmark_indices, landmarks_img)

In [None]:
# Starting point of the joint fit: all-zero shape and expression coefficients,
# seen from the initial camera pose.
initial_params = sparse_optimization.create_parameters(
    [0] * n_shape_coefficients,
    [0] * n_expression_coefficients,
    initial_camera_pose,
)

In [None]:
# Sanity check: converting the parameters to the flat vector theta and back must be lossless.
# np.array_equal compares shape and values and always yields a single bool, whereas
# all(a == b) raises a confusing TypeError when the element-wise comparison collapses
# to a scalar (e.g. when the two vectors differ in length).
assert np.array_equal(
    sparse_optimization.create_parameters_from_theta(initial_params.to_theta()).to_theta(),
    initial_params.to_theta(),
), "OptimizationParameters is ill-defined"

In [None]:
# Display the flat parameter vector that the solver receives as its starting point.
initial_params.to_theta()

### 4.1.2. Run Joint Optimization

In [None]:
# Fit all parameters at once with scipy's least-squares solver.
# This typically takes 20 seconds.
# max_nfev caps the number of loss evaluations; verbose=2 prints progress per iteration.
result = optimize.least_squares(
    loss,
    initial_params.to_theta(),
    max_nfev=100,
    verbose=2,
)

In [None]:
# Solution found by the solver, as a flat parameter vector (theta)
result.x

In [None]:
# Final cost: value of the solver's cost function at the solution
result.cost

In [None]:
# Solver's textual description of why it terminated
result.message

In [None]:
# Convert the solver's flat solution vector back into a structured parameter object.
params = sparse_optimization.create_parameters_from_theta(result.x)

## 4.2. Alternating Optimization

### 4.2.1. Setup Alternating Optimization

In [None]:
# Budgets for the alternating scheme below.
# The two n_iterations_* values are passed to the solver as max_nfev (a cap on loss
# evaluations) for the face solve and the camera solve respectively;
# n_dual_iterations is the number of camera-then-face rounds.
n_iterations_face = 2 * (n_params_shape + n_params_expression)
n_iterations_camera = 20
n_dual_iterations = 10

In [None]:
# Optimizer for the face parameters; constructed with fix_camera_pose=True.
face_optimizer = SparseOptimization(
    bfm,
    n_params_shape,
    n_params_expression,
    fix_camera_pose=True,
    weight_shape_params=weight_shape_params,
    weight_expression_params=weight_expression_params,
)

# Optimizer for the camera pose; constructed with zero shape and expression parameters.
camera_optimizer = SparseOptimization(bfm, 0, 0, fix_camera_pose=False)

In [None]:
# Starting point of the alternating scheme: all-zero shape and expression coefficients
# together with the initial camera pose. This replaces the `params` of the joint fit.
params = face_optimizer.create_parameters(
    [0] * n_shape_coefficients,
    [0] * n_expression_coefficients,
    initial_camera_pose,
)

In [None]:
# Initial losses for the alternating scheme. Each loss receives, as fixed values, the
# quantities that its optimizer does not fit.
face_optimizer_loss = face_optimizer.create_loss(
    renderer,
    bfm_landmark_indices,
    landmarks_img,
    fixed_camera_pose=initial_camera_pose,
)

# The fixed coefficients come from `params` (created for this section), not from
# `initial_params` of the joint-optimization section. This matches what the alternating
# loop does on every round and removes the dependency on the joint parameter object.
camera_optimizer_loss = camera_optimizer.create_loss(
    renderer,
    bfm_landmark_indices,
    landmarks_img,
    fixed_shape_coefficients=params.shape_coefficients,
    fixed_expression_coefficients=params.expression_coefficients,
)

### 4.2.2. Run Alternating Optimization

In [None]:
# Alternate between a camera solve and a face solve. `params` is reassigned several
# times per round, so the order of the statements below matters.
for iteration in range(n_dual_iterations):
    # Camera step: the loss keeps shape and expression fixed at the current estimate.
    camera_optimizer_loss = camera_optimizer.create_loss(
        renderer,
        bfm_landmark_indices,
        landmarks_img,
        fixed_shape_coefficients=params.shape_coefficients,
        fixed_expression_coefficients=params.expression_coefficients,
    )
    # Start from the pose that the most recent face loss held fixed.
    params = camera_optimizer.create_parameters(camera_pose=face_optimizer_loss.fixed_camera_pose)
    result = optimize.least_squares(
        camera_optimizer_loss,
        params.to_theta(),
        max_nfev=n_iterations_camera,
        verbose=2,
    )
    print(result.cost)
    params = camera_optimizer.create_parameters_from_theta(result.x)

    # Face step: the loss keeps the camera pose fixed at the pose just found.
    face_optimizer_loss = face_optimizer.create_loss(
        renderer,
        bfm_landmark_indices,
        landmarks_img,
        fixed_camera_pose=params.camera_pose,
    )
    # Start from the coefficients that the camera loss of this round held fixed.
    params = face_optimizer.create_parameters(
        shape_coefficients=camera_optimizer_loss.fixed_shape_coefficients,
        expression_coefficients=camera_optimizer_loss.fixed_expression_coefficients,
    )
    result = optimize.least_squares(
        face_optimizer_loss,
        params.to_theta(),
        max_nfev=n_iterations_face,
        verbose=2,
    )
    print(result.cost)
    params = face_optimizer.create_parameters_from_theta(result.x)

In [None]:
# Hard-coded cost values for the alternating run (costs_dual); they are not recomputed
# by this notebook.
# NOTE(review): presumably copied from the solver's verbose output of an earlier run — confirm.
costs_dual = [3.2807e+05, 4.8681e+04, 3.9522e+04, 8.1128e+03, 5.9347e+03, 3.2546e+03, 3.1953e+03, 3.1863e+03, 3.1856e+03, 3.1855e+03, 3.1855e+03, 3.1855e+03, 3.1855e+03, 3.1855e+03, 3.1855e+03, 3.1855e+03, 3.1855e+03, 3.1855e+03, 3.1855e+03, 4.0271e+02, 3.8531e+02, 3.7645e+02, 3.7607e+02, 3.7505e+02, 3.7504e+02, 3.7504e+02, 3.7503e+02, 3.7503e+02, 3.7503e+02, 3.7503e+02, 3.7503e+02, 3.7503e+02, 3.7503e+02, 3.7503e+02, 3.7503e+02, 3.7503e+02, 3.5441e+02, 3.5427e+02, 3.5424e+02, 3.5423e+02, 3.5423e+02, 3.5423e+02, 3.5423e+02, 3.5423e+02, 3.5423e+02, 3.5423e+02, 3.5423e+02, 3.4254e+02, 3.3994e+02, 3.3827e+02, 3.3746e+02, 3.3719e+02, 3.3695e+02, 3.3687e+02, 3.3684e+02, 3.3683e+02, 3.3683e+02, 3.3682e+02, 3.3682e+02, 3.3682e+02, 3.3682e+02, 3.3682e+02, 3.3682e+02, 3.3089e+02, 3.2987e+02, 3.2927e+02, 3.2926e+02, 3.2926e+02, 3.2926e+02, 3.2926e+02, 3.2926e+02, 3.2926e+02, 3.2926e+02, 3.2926e+02, 3.2926e+02, 3.2926e+02, 3.2420e+02, 3.2355e+02, 3.2353e+02, 3.2353e+02, 3.2352e+02, 3.2352e+02, 3.2352e+02, 3.2352e+02, 3.2352e+02, 3.2352e+02, 3.2352e+02, 3.2352e+02, 3.2352e+02, 3.2352e+02, 3.1899e+02, 3.1864e+02, 3.1860e+02, 3.1857e+02, 3.1856e+02, 3.1856e+02, 3.1856e+02, 3.1856e+02, 3.1856e+02, 3.1856e+02, 3.1856e+02, 3.1856e+02, 3.1856e+02, 3.1630e+02, 3.1606e+02, 3.1591e+02, 3.1586e+02, 3.1584e+02, 3.1583e+02, 3.1583e+02, 3.1583e+02, 3.1583e+02, 3.1583e+02, 3.1583e+02, 3.1583e+02, 3.1292e+02, 3.1290e+02, 3.1290e+02, 3.1290e+02, 3.1290e+02, 3.1290e+02, 3.1290e+02, 3.1290e+02, 3.1290e+02, 3.1290e+02, 3.1290e+02, 3.0767e+02, 3.0446e+02, 3.0440e+02, 3.0427e+02, 3.0425e+02, 3.0424e+02, 3.0424e+02, 3.0423e+02, 3.0423e+02, 3.0423e+02, 3.0423e+02, 3.0423e+02, 3.0087e+02, 3.0081e+02, 3.0081e+02, 3.0081e+02, 3.0081e+02, 3.0081e+02, 3.0081e+02, 3.0081e+02, 3.0081e+02, 2.9878e+02, 2.9860e+02, 2.9854e+02, 2.9851e+02, 2.9850e+02, 2.9849e+02, 2.9848e+02, 2.9847e+02, 2.9847e+02, 2.9847e+02, 2.9847e+02, 2.9847e+02, 2.9847e+02, 2.9847e+02, 2.9847e+02, 2.9613e+02, 2.9613e+02, 2.9613e+02, 
2.9612e+02, 2.9612e+02, 2.9612e+02, 2.9612e+02, 2.9612e+02, 2.9612e+02, 2.9612e+02, 2.9612e+02, 2.9612e+02, 2.9612e+02, 2.9300e+02, 2.9265e+02, 2.9253e+02, 2.9250e+02, 2.9249e+02, 2.9248e+02, 2.9248e+02, 2.9248e+02, 2.9248e+02, 2.9248e+02, 2.9248e+02, 2.9248e+02, 2.9248e+02, 2.9018e+02, 2.9017e+02, 2.9017e+02, 2.9017e+02, 2.9017e+02, 2.9017e+02, 2.9017e+02, 2.9017e+02, 2.9017e+02, 2.9017e+02, 2.9017e+02, 2.9017e+02, 2.9017e+02, 2.8814e+02, 2.8810e+02, 2.8800e+02, 2.8799e+02, 2.8798e+02, 2.8798e+02, 2.8798e+02, 2.8798e+02, 2.8798e+02, 2.8798e+02, 2.8631e+02, 2.8599e+02, 2.8596e+02, 2.8596e+02, 2.8596e+02, 2.8596e+02, 2.8596e+02, 2.8596e+02, 2.8596e+02, 2.8488e+02, 2.8433e+02, 2.8430e+02, 2.8427e+02, 2.8426e+02, 2.8426e+02, 2.8426e+02, 2.8426e+02, 2.8426e+02, 2.8426e+02, 2.8426e+02, 2.8426e+02, 2.8253e+02, 2.8252e+02, 2.8252e+02, 2.8252e+02, 2.8252e+02, 2.8252e+02, 2.8252e+02, 2.8252e+02, 2.8168e+02, 2.8125e+02, 2.8119e+02, 2.8114e+02, 2.8111e+02, 2.8110e+02, 2.8109e+02, 2.8109e+02, 2.8109e+02, 2.8109e+02, 2.8109e+02]
# Hard-coded cost values for the joint run.
costs_joint = [3.2807e+05, 4.3899e+04, 2.0022e+04, 1.9179e+03, 7.7233e+02, 4.3539e+02, 3.6307e+02, 2.9366e+02, 2.5590e+02, 2.5270e+02, 2.5069e+02, 2.5013e+02, 2.4999e+02, 2.4998e+02, 2.4996e+02]
# Compare the recorded cost curves of both schemes. The leading entries are skipped
# (the first 20 of the alternating run, the first 5 of the joint run), so the x axis
# counts from the first plotted entry of each curve rather than from the start of the run.
plt.plot(costs_dual[20:], label='Alternating Optimization')
plt.plot(costs_joint[5:], label='Joint Optimization')
plt.title("Joint vs Alternating Optimization")
plt.xlabel("Iteration")
plt.ylabel("Cost")
plt.legend()
plt.show()

In [None]:
# Collect the alternating result in a single parameter object: the coefficients come
# from the last face solve (`params`), the camera pose from the loss that solve used.
# From here on `params` refers to the alternating result.
sparse_optimization = SparseOptimization(
    bfm,
    n_params_shape,
    n_params_expression,
    weight_shape_params=weight_shape_params,
    weight_expression_params=weight_expression_params,
)
params = sparse_optimization.create_parameters(
    shape_coefficients=params.shape_coefficients,
    expression_coefficients=params.expression_coefficients,
    camera_pose=face_optimizer_loss.fixed_camera_pose,
)

# 5. Draw mask on input image

In [None]:
# Helper used below to save the current matplotlib figure under the
# "sparse_reconstruction" name (see plot_manager.save_current_plot calls).
plot_manager = PlotManager("sparse_reconstruction")

In [None]:
# Build the face mesh for the fitted shape and expression; the color coefficients are all zero.
face_mesh = bfm.draw_sample(
    shape_coefficients=params.shape_coefficients,
    expression_coefficients=params.expression_coefficients,
    color_coefficients=[0] * n_color_coefficients,
)

# Project every mesh vertex into the image using the fitted camera pose.
face_pixels = renderer.project_points(params.camera_pose, face_mesh.vertices)

In [None]:
# Reload a fresh copy of the input photo and convert it for display with matplotlib.
img = cv2_to_plt(cv2.imread(img_path))

In [None]:
# Project only the landmark vertices of the fitted mesh into the image.
pixels_bfm_landmarks = renderer.project_points(
    params.camera_pose,
    np.array(face_mesh.vertices)[bfm_landmark_indices],
)

In [None]:
# Overlay both landmark sets on the image. The return values are discarded, so
# draw_pixels_to_image presumably draws into `img` in place — confirm.
# Detected landmarks use color [0, 255, 0] and projected model landmarks use [255, 0, 0]
# (green and red, assuming cv2_to_plt yields RGB channel order — confirm).
# The disabled call below would additionally draw every projected mesh vertex.
#draw_pixels_to_image(img, face_pixels, color=1)
draw_pixels_to_image(img, landmarks_img, color=[0, 255, 0])
draw_pixels_to_image(img, pixels_bfm_landmarks, color=[255, 0, 0])

In [None]:
# Show the image with both landmark sets and save the figure as a PDF.
# The figure is saved before plt.show() is called.
# NOTE: img_name still contains its extension, so the saved file name ends in ".jpg.pdf".
plt.figure(figsize=(20, 14))
plt.imshow(img)
plot_manager.save_current_plot(f"landmarks_fitting_{img_name}.pdf")
plt.show()

# 6. Render full mask

## 6.1 Setup scene

In [None]:
# Camera for the full render: same vertical field of view as before, now with the
# aspect ratio of the input image set explicitly. This rebinds `perspective_camera`.
perspective_camera = pyrender.PerspectiveCamera(
    yfov=np.pi / 3.0,
    aspectRatio=img_width / img_height,
)

# White directional light with intensity 2.0.
directional_light = pyrender.DirectionalLight(
    color=[1.0, 1.0, 1.0],
    intensity=2.0,
)

In [None]:
# Convert the fitted face mesh to a trimesh object, which pyrender.Mesh.from_trimesh
# consumes when the scene is assembled below.
face_trimesh = bfm.convert_to_trimesh(face_mesh)

In [None]:
# Assemble the scene: the face mesh is added with the fitted pose (params.camera_pose)
# as its node pose, while the camera and the light are added without an explicit pose.
scene = pyrender.Scene()
scene.add(pyrender.Mesh.from_trimesh(face_trimesh), pose=params.camera_pose)
scene.add(perspective_camera)
scene.add(directional_light)

## 6.2 Interactive rendering (face only)

In [None]:
# Open pyrender's interactive viewer on the scene, sized like the input image.
# NOTE(review): with these arguments the viewer presumably needs a display and blocks
# the cell until its window is closed — confirm before running headless or via Run All.
pyrender.Viewer(scene, use_raymond_lighting=True, viewport_size=(img_width, img_height))

## 6.3 Render face onto input image

In [None]:
# Render the scene off-screen at the resolution of the input image.
r = pyrender.OffscreenRenderer(img_width, img_height)

# Release the renderer even if rendering raises, so that it is not leaked
# on a failed cell execution.
try:
    color, depth = r.render(scene)
finally:
    r.delete()

In [None]:
# Boolean mask of the pixels with a non-zero depth value; used below to select which
# pixels of the photo are replaced by the rendering.
# NOTE(review): assumes the renderer reports depth 0 where no geometry was drawn — confirm.
depth_mask = depth != 0

In [None]:
# Reload a fresh copy of the input photo, converted for display with matplotlib.
img = cv2_to_plt(cv2.imread(img_path))

# Replace the photo's pixels with the rendered face wherever the depth mask is set.
img[depth_mask] = color[depth_mask]

In [None]:
# Show the composited image and save the figure as a PDF. The figure size is the image
# size divided by 50 in each dimension, which keeps the image's aspect ratio.
# The figure is saved before plt.show() is called.
# NOTE: img_name still contains its extension, so the saved file name ends in ".jpg.pdf".
plt.figure(figsize=(img_width / 50, img_height / 50))
plt.imshow(img)
plot_manager.save_current_plot(f"mask_fitting_{img_name}.pdf")
plt.show()