Provides helpers to visualize and plot the frames reconstructed by a `video_reconstruction` run.

In [None]:
%cd ..
%reload_ext autoreload
%autoreload 2

In [None]:
import pyrender
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec

from face_reconstruction.plots import PlotManager, plot_params
from face_reconstruction.model import BaselFaceModel
from face_reconstruction.pipeline import BFMPreprocessor
from face_reconstruction.utils.io import list_file_numbering
from face_reconstruction.landmarks import detect_landmarks
from face_reconstruction.graphics import draw_pixels_to_image
from env import PLOTS_PATH

# 1. Load run

In [None]:
# ID of the `video_reconstruction` run to inspect; its outputs are read from
# video_reconstruction/run-<run_id> (see the PlotManager setup and get_frames).
run_id = 23

In [None]:
# PlotManager for this run's output folder (presumably relative to PLOTS_PATH,
# matching get_frames). It is used to load stored params / param histories and
# to save generated plots.
plot_manager = PlotManager(f"video_reconstruction/run-{run_id}")
# Loads RGB/depth frames and renders/evaluates fitted parameters for plotting.
preprocessor = BFMPreprocessor()

# 2. Helper functions

In [None]:
def load_params(frame_id):
    """Return the fitted parameters that were stored for ``frame_id`` in this run.

    The parameters are read through ``plot_manager`` from the entry
    ``params_<frame_id, zero-padded to 4 digits>``. They are loaded against the
    preprocessor's face model (``preprocessor.bfm``).
    """
    params_name = f"params_{frame_id:04d}"
    face_model = preprocessor.bfm
    return plot_manager.load_params(params_name, face_model)

In [None]:
def load_param_history(frame_id):
    """Return the stored parameter history for ``frame_id`` in this run.

    The history is read through ``plot_manager`` from the entry
    ``param_history_<frame_id, zero-padded to 4 digits>``. It is loaded against
    the preprocessor's face model (``preprocessor.bfm``).
    """
    history_name = f"param_history_{frame_id:04d}"
    face_model = preprocessor.bfm
    return plot_manager.load_param_history(history_name, face_model)

In [None]:
def plot_reconstruction_error(frame_id):
    """Plot the reconstruction error of the fitted parameters for ``frame_id``.

    The value returned by ``preprocessor.plot_reconstruction_error`` is written
    into the current axes' x-label, formatted to three decimals.
    """
    preprocessor.load_frame(frame_id)
    # NOTE(review): presumably lifts the loaded frame into 3D; needed before evaluating the error.
    preprocessor.to_3d()
    fitted = load_params(frame_id)
    mean_error = preprocessor.plot_reconstruction_error(fitted)
    plt.xlabel(f"Mean Reconstruction Error: {mean_error:.3f}")

In [None]:
def plot_rgb(frame_id):
    """Show the RGB image of ``frame_id`` on the current matplotlib axes."""
    preprocessor.load_frame(frame_id)
    rgb_image = preprocessor.img
    plt.imshow(rgb_image)

In [None]:
def plot_rgb_with_landmarks(frame_id):
    """Show the RGB image of ``frame_id`` with its detected landmarks drawn as green pixels."""
    preprocessor.load_frame(frame_id)
    # np.array copies the image, so drawing does not modify preprocessor.img.
    canvas = np.array(preprocessor.img)
    detected = detect_landmarks(canvas)
    green = [0, 255, 0]
    draw_pixels_to_image(canvas, detected, color=green)
    plt.imshow(canvas)

In [None]:
def plot_depth(frame_id):
    """Show the depth image of ``frame_id`` on the current matplotlib axes."""
    preprocessor.load_frame(frame_id)
    depth_image = preprocessor.depth_img
    plt.imshow(depth_image)

In [None]:
def plot_mask(frame_id):
    """Show the fitted model for ``frame_id`` rendered onto the frame image.

    NOTE(review): presumably ``render_onto_img`` draws onto the RGB frame; confirm in BFMPreprocessor.
    """
    preprocessor.load_frame(frame_id)
    preprocessor.to_3d()
    fitted = load_params(frame_id)
    overlay = preprocessor.render_onto_img(fitted)
    plt.imshow(overlay)

In [None]:
def generate_param_history_video(frame_id):
    """Write the parameter history of ``frame_id`` as images and build a video from them.

    The history is handed to ``preprocessor.store_param_history`` with the
    sub-directory ``param_history/<frame_id>/`` of this run. ``plot_manager`` then
    generates a video in that directory from files with prefix ``iteration_`` and
    suffix ``.jpg`` (presumably the per-iteration images just written).
    """
    history = load_param_history(frame_id)
    history_dir = f"param_history/{frame_id}/"
    preprocessor.load_frame(frame_id)
    preprocessor.to_3d()
    preprocessor.store_param_history(plot_manager, history_dir, history)
    plot_manager.cd(history_dir).generate_video('iteration_', '.jpg')

In [None]:
def get_frames():
    """Return the frame numbers of this run.

    The numbers are taken from files named ``frame_<n>.jpg`` in
    ``PLOTS_PATH/video_reconstruction/run-<run_id>``.
    """
    run_dir = f"{PLOTS_PATH}/video_reconstruction/run-{run_id}"
    return list_file_numbering(run_dir, 'frame_', '.jpg')

# 3. Example plots

In [None]:
# Reconstruction-error plot for frame 0; the mean error appears as the x-axis label.
plot_reconstruction_error(0)

In [None]:
# RGB input of frame 1.
plot_rgb(1)

In [None]:
# Depth input of frame 1.
plot_depth(1)

In [None]:
# Fitted parameters of frame 0 rendered onto the frame (the "Fitted Mask" panel).
plot_mask(0)

# 4. Generate final plots for all frames

In [None]:
# Image grid: 2 rows x 4 columns; the upper row is three times as tall as the lower one.
gs = gridspec.GridSpec(nrows=2, ncols=4, height_ratios=[3, 1])
# Coefficient grid: same row split but only 2 columns, so its lower cells are wider.
gs2 = gridspec.GridSpec(nrows=2, ncols=2, height_ratios=[3, 1])

In [None]:
def _hide_ticks():
    """Hide all tick marks and tick labels on the current axes (used for image panels)."""
    plt.tick_params(axis='both', which='both', bottom=False, top=False, labelbottom=False,
                    right=False, left=False, labelleft=False)


for frame_id in get_frames():

    fig = plt.figure(figsize=(20, 10))

    # Upper row of `gs`: the four image panels, drawn without axis ticks.
    fig.add_subplot(gs[0])
    plt.title('RGB Input')
    _hide_ticks()
    plot_rgb(frame_id)

    fig.add_subplot(gs[1])
    plt.title('Depth Input')
    _hide_ticks()
    plot_depth(frame_id)

    fig.add_subplot(gs[2])
    plt.title('Fitted Mask')
    _hide_ticks()
    plot_mask(frame_id)

    fig.add_subplot(gs[3])
    plt.title('Reconstruction Error')
    # NOTE(review): fixed x-range so panels are comparable across frames; confirm it matches
    # what BFMPreprocessor.plot_reconstruction_error draws.
    plt.xlim(0, 150)
    _hide_ticks()
    plot_reconstruction_error(frame_id)

    # Load the fitted parameters once; both coefficient panels below use them.
    params = load_params(frame_id)

    # Lower row of `gs2`: coefficient plots on a fixed y-range so frames are comparable.
    fig.add_subplot(gs2[2])
    plt.title("Shape Coefficients")
    plt.ylim(-10, 10)
    plot_params(params.shape_coefficients)

    fig.add_subplot(gs2[3])
    plt.title("Expression Coefficients")
    plt.ylim(-10, 10)
    plot_params(params.expression_coefficients, color='orange')

    plot_manager.save_current_plot(f"final/frame_{frame_id:04d}.jpg")

    plt.show()
    # Close this frame's figure explicitly so figures do not accumulate over the loop.
    plt.close(fig)

# 5. Param History

In [None]:
# Build the parameter-history video for frame 0.
generate_param_history_video(0)