### Initialize

In [1]:
import time

import os
import contextlib
import os.path as osp
import numpy as np
import cv2
import torch
import yaml
import tyro
import subprocess
from rich.progress import track
import torchvision
import cv2
import threading
import queue
import torchvision.transforms as transforms
from concurrent.futures import ThreadPoolExecutor, as_completed
import glob
import os
import numpy as np
import time
import torch
import imageio

from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig
from src.config.crop_config import CropConfig

def partial_fields(target_class, kwargs):
    return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})

args = ArgumentConfig()
inference_cfg = partial_fields(InferenceConfig, args.__dict__)
crop_cfg = partial_fields(CropConfig, args.__dict__)
# print("inference_cfg: ", inference_cfg)
# print("crop_cfg: ", crop_cfg)
device = 'cuda'
print("Compile complete")

'''
Common modules
'''

from src.utils.helper import load_model, concat_feat
from src.utils.camera import headpose_pred_to_degree, get_rotation_matrix
from src.utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
from src.config.inference_config import InferenceConfig
from src.utils.cropper import Cropper
from src.utils.camera import get_rotation_matrix
from src.utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream
from src.utils.crop import _transform_img, prepare_paste_back, paste_back
from src.utils.io import load_image_rgb, load_video, resize_to_limit, dump, load
from src.utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix, is_image
from src.utils.filter import smooth


'''
Util functions
'''

def calculate_distance_ratio(lmk: np.ndarray, idx1: int, idx2: int, idx3: int, idx4: int, eps: float = 1e-6) -> np.ndarray:
    return (np.linalg.norm(lmk[:, idx1] - lmk[:, idx2], axis=1, keepdims=True) /
            (np.linalg.norm(lmk[:, idx3] - lmk[:, idx4], axis=1, keepdims=True) + eps))


def calc_eye_close_ratio(lmk: np.ndarray, target_eye_ratio: np.ndarray = None) -> np.ndarray:
    lefteye_close_ratio = calculate_distance_ratio(lmk, 6, 18, 0, 12)
    righteye_close_ratio = calculate_distance_ratio(lmk, 30, 42, 24, 36)
    if target_eye_ratio is not None:
        return np.concatenate([lefteye_close_ratio, righteye_close_ratio, target_eye_ratio], axis=1)
    else:
        return np.concatenate([lefteye_close_ratio, righteye_close_ratio], axis=1)


def calc_lip_close_ratio(lmk: np.ndarray) -> np.ndarray:
    return calculate_distance_ratio(lmk, 90, 102, 48, 66)

def calc_ratio(lmk_lst):
    input_eye_ratio_lst = []
    input_lip_ratio_lst = []
    for lmk in lmk_lst:
        # for eyes retargeting
        input_eye_ratio_lst.append(calc_eye_close_ratio(lmk[None]))
        # for lip retargeting
        input_lip_ratio_lst.append(calc_lip_close_ratio(lmk[None]))
    return input_eye_ratio_lst, input_lip_ratio_lst

def prepare_videos_(imgs, device):
    """ construct the input as standard
    imgs: NxHxWx3, uint8
    """
    if isinstance(imgs, list):
        _imgs = np.array(imgs)
    elif isinstance(imgs, np.ndarray):
        _imgs = imgs
    else:
        raise ValueError(f'imgs type error: {type(imgs)}')

    # y = _imgs.astype(np.float32) / 255.
    y = _imgs
    y = torch.from_numpy(y).permute(0, 3, 1, 2)  # NxHxWx3 -> Nx3xHxW
    y = y.to(device)
    y = y / 255.
    y = torch.clamp(y, 0, 1)

    return y

def get_kp_info(x: torch.Tensor, **kwargs) -> dict:
    """ get the implicit keypoint information
    x: Bx3xHxW, normalized to 0~1
    flag_refine_info: whether to trandform the pose to degrees and the dimention of the reshape
    return: A dict contains keys: 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'
    """
    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16,
                                 enabled=inference_cfg.flag_use_half_precision):
        kp_info = motion_extractor(x)

        if inference_cfg.flag_use_half_precision:
            # float the dict
            for k, v in kp_info.items():
                if isinstance(v, torch.Tensor):
                    kp_info[k] = v.float()

    flag_refine_info: bool = kwargs.get('flag_refine_info', True)
    if flag_refine_info:
        bs = kp_info['kp'].shape[0]
        kp_info['pitch'] = headpose_pred_to_degree(kp_info['pitch'])[:, None]  # Bx1
        kp_info['yaw'] = headpose_pred_to_degree(kp_info['yaw'])[:, None]  # Bx1
        kp_info['roll'] = headpose_pred_to_degree(kp_info['roll'])[:, None]  # Bx1
        kp_info['kp'] = kp_info['kp'].reshape(bs, -1, 3)  # BxNx3
        kp_info['exp'] = kp_info['exp'].reshape(bs, -1, 3)  # BxNx3

    return kp_info

def read_video_frames(video_path):
    cap = cv2.VideoCapture(video_path)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    frames = []
    for _ in range(frame_count):
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = cv2.resize(frame, (256, 256))  # Resize to 256x256
        frames.append(frame)

    cap.release()
    return video_path, frames

def read_multiple_videos(video_paths, num_threads=4):
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        results = list(executor.map(read_video_frames, video_paths))
    return results

'''
Main module for inference
'''
model_config = yaml.load(open(inference_cfg.models_config, 'r'), Loader=yaml.SafeLoader)
# init F
appearance_feature_extractor = load_model(inference_cfg.checkpoint_F, model_config, device, 'appearance_feature_extractor')
# init M
motion_extractor = load_model(inference_cfg.checkpoint_M, model_config, device, 'motion_extractor')
# init W
warping_module = load_model(inference_cfg.checkpoint_W, model_config, device, 'warping_module')
# init G
spade_generator = load_model(inference_cfg.checkpoint_G, model_config, device, 'spade_generator')
# init S and R
if inference_cfg.checkpoint_S is not None and os.path.exists(inference_cfg.checkpoint_S):
    stitching_retargeting_module = load_model(inference_cfg.checkpoint_S, model_config, device, 'stitching_retargeting_module')
else:
    stitching_retargeting_module = None

cropper = Cropper(crop_cfg=crop_cfg, device=device)


Compile complete


### Load input

In [2]:
# global variables

video_dir = '/mnt/e/data/vox2/videos/512/id00774/'
output_dir = '/mnt/c/Users/mjh/Downloads/out_test/'
read_2_gpu_batch_size = 2048
gpu_batch_size = 16
file_processed_indx = 0
process_queue = torch.Tensor().to(device)

In [3]:
# Read all video files
video_paths = sorted(glob.glob(os.path.join(video_dir, '**', '*.mp4'), recursive=True))  # Search recursively

print(f"Found {len(video_paths)} video files.")

video_frames = read_multiple_videos(video_paths, num_threads=4)
all_frames = []
total_frames = 0
video_lengths = []
vid_keys= []

for video_path, frames in video_frames:
    all_frames.extend(frames)
    frame_count = len(frames)
    total_frames += frame_count
    video_lengths.append(frame_count)
    print(video_path)
    clip_id = video_path.split('/')[-1].split('.')[0]
    url_id = video_path.split('/')[-2]
    vid_keys.append(url_id+'+'+clip_id+'+'+str(frame_count))
    print(f"Processed video: {clip_id} - {url_id}, frames: {frame_count}")

print(f"\nTotal frames across all videos: {total_frames}")
print(f"Video lengths: {video_lengths}")

# Convert to numpy array
all_frames = np.array(all_frames)

print(f"Shape of concatenated array: {all_frames.shape}")

Found 294 video files.
/mnt/e/data/vox2/videos/512/id00774/0Noa8soq03Y/00001.mp4
Processed video: 00001 - 0Noa8soq03Y, frames: 127
/mnt/e/data/vox2/videos/512/id00774/0Noa8soq03Y/00002.mp4
Processed video: 00002 - 0Noa8soq03Y, frames: 109
/mnt/e/data/vox2/videos/512/id00774/0Noa8soq03Y/00003.mp4
Processed video: 00003 - 0Noa8soq03Y, frames: 456
/mnt/e/data/vox2/videos/512/id00774/0Noa8soq03Y/00004.mp4
Processed video: 00004 - 0Noa8soq03Y, frames: 114
/mnt/e/data/vox2/videos/512/id00774/0Noa8soq03Y/00005.mp4
Processed video: 00005 - 0Noa8soq03Y, frames: 131
/mnt/e/data/vox2/videos/512/id00774/0Noa8soq03Y/00006.mp4
Processed video: 00006 - 0Noa8soq03Y, frames: 548
/mnt/e/data/vox2/videos/512/id00774/0Noa8soq03Y/00007.mp4
Processed video: 00007 - 0Noa8soq03Y, frames: 110
/mnt/e/data/vox2/videos/512/id00774/0Noa8soq03Y/00008.mp4
Processed video: 00008 - 0Noa8soq03Y, frames: 223
/mnt/e/data/vox2/videos/512/id00774/0Noa8soq03Y/00009.mp4
Processed video: 00009 - 0Noa8soq03Y, frames: 169
/mnt/

In [4]:
read_data_2_gpu = True
total_frames = len(all_frames)
read_data_2_gpu_pointer = 0
all_process_time = 0
all_write_time = 0

while read_data_2_gpu:
    start_process_time = time.time()
    # Calculate the batch size for this iteration
    current_batch_size = min(read_2_gpu_batch_size, total_frames - read_data_2_gpu_pointer)

    # Read the next batch of frames
    batch_input = all_frames[read_data_2_gpu_pointer:read_data_2_gpu_pointer + current_batch_size]
    batch_input = prepare_videos_(batch_input, device)

    # Process the batch in mini-batches
    # prepare mini-batches varaiables
    mini_batch_start = 0
    all_info = []
    while mini_batch_start < batch_input.shape[0]:
        mini_batch_end = min(mini_batch_start + gpu_batch_size, batch_input.shape[0])
        mini_batch = batch_input[mini_batch_start:mini_batch_end]

        x_info = get_kp_info(mini_batch)

        # Concatenate the tensors
        concat_tensor = torch.cat([
            x_info['exp'].reshape(mini_batch_end - mini_batch_start, -1),
            x_info['t'],
            torch.cat([x_info['pitch'], x_info['yaw'], x_info['roll']], dim=1),
        ], dim=1)

        all_info.append(concat_tensor)

        mini_batch_start = mini_batch_end
    all_info_tensor = torch.cat(all_info, dim=0)
    all_process_time = time.time() - start_process_time

    start_write_time = time.time()
    # add to process queue
    process_queue = torch.cat((process_queue, all_info_tensor), dim=0)
    write_to_disk_count = 0
    write_to_disk_flag = True
    while write_to_disk_flag and len(vid_keys) > 0:
        current_vid_key = vid_keys[0]
        current_frame_count = video_lengths[0]

        if len(process_queue) >= current_frame_count:
            # We have enough frames to write a complete video file
            video_tensor = process_queue[:current_frame_count]
            # Save tensor to disk
            # save_path = os.path.join(output_dir, f"{current_vid_key}.pt")
            # torch.save(video_tensor, save_path)
            # Save tensor to disk using numpy
            save_path = os.path.join(output_dir, f"{current_vid_key}.npy")
            np.save(save_path, video_tensor.cpu().numpy())

            # Remove processed frames from queue
            process_queue = process_queue[current_frame_count:]
            # Remove processed video key
            vid_keys.pop(0)
            video_lengths.pop(0)

            write_to_disk_count += 1
        else:
            # Not enough frames in the queue, exit the loop
            write_to_disk_flag = False
    all_write_time = time.time() - start_write_time
    # Update pointers and counters
    read_data_2_gpu_pointer += current_batch_size
    file_processed_indx += current_batch_size

    if read_data_2_gpu_pointer >= total_frames:
        read_data_2_gpu = False

    print(f"Processed {current_batch_size} frames in {all_process_time:.2f} seconds. Saved {write_to_disk_count} video files in {all_write_time:.2f} seconds.")

  return F.conv2d(input, weight, bias, self.stride,


Processed 2048 frames in 2.95 seconds. Saved 9 video files in 0.15 seconds.
Processed 2048 frames in 1.64 seconds. Saved 10 video files in 0.04 seconds.
Processed 2048 frames in 1.65 seconds. Saved 10 video files in 0.15 seconds.
Processed 2048 frames in 1.61 seconds. Saved 11 video files in 0.04 seconds.
Processed 2048 frames in 1.57 seconds. Saved 8 video files in 0.03 seconds.
Processed 2048 frames in 1.56 seconds. Saved 9 video files in 0.09 seconds.
Processed 2048 frames in 1.60 seconds. Saved 7 video files in 0.02 seconds.
Processed 2048 frames in 1.62 seconds. Saved 8 video files in 0.07 seconds.
Processed 2048 frames in 1.61 seconds. Saved 8 video files in 0.11 seconds.
Processed 2048 frames in 1.67 seconds. Saved 6 video files in 0.02 seconds.
Processed 2048 frames in 1.65 seconds. Saved 7 video files in 0.04 seconds.
Processed 2048 frames in 1.61 seconds. Saved 9 video files in 0.13 seconds.
Processed 2048 frames in 1.61 seconds. Saved 6 video files in 0.03 seconds.
Processed

Read json ( input UNION with audio)

In [11]:
import json
import os

json_path = '/mnt/c/Users/mjh/Downloads/output_union_512.json'
root_dir = '/mnt/e/data/vox2/videos/512/'  # Replace this with your actual root directory

# Load the JSON file
with open(json_path, 'r') as f:
    data = json.load(f)

video_paths = {}

# Iterate through the JSON structure
for first_level_key in data:
    video_paths[first_level_key] = []
    for second_level_key in data[first_level_key]:
        for third_level_key in data[first_level_key][second_level_key]:
            third_level_key = third_level_key.split('.')[0]
            # Construct the video path
            video_path = os.path.join(root_dir, first_level_key, second_level_key, f"{third_level_key}.mp4")
            video_paths[first_level_key].append(video_path)

# Print the results
print(f"Generated {len(video_paths)} video paths.")
print(len(video_paths['id00774']))

Generated 4617 video paths.
279


### Data distribution

In [None]:
driving_rgb_lst = all_frames[2000:2500]
I_d_lst = prepare_videos_(driving_rgb_lst, device)
I_d_lst = I_d_lst.unsqueeze(1)
# I_d_lst = I_d_lst[0 : video_lengths[0]]

In [None]:
I_d_lst.shape

In [None]:
batch_size = 16

In [None]:
batch_start = 0
all_info = []
all_pose = []
all_t = []
all_scale = []
all_exp = []

def euler_to_quaternion(pitch, yaw, roll):
    cy = torch.cos(yaw * 0.5)
    sy = torch.sin(yaw * 0.5)
    cp = torch.cos(pitch * 0.5)
    sp = torch.sin(pitch * 0.5)
    cr = torch.cos(roll * 0.5)
    sr = torch.sin(roll * 0.5)

    w = cr * cp * cy + sr * sp * sy
    x = sr * cp * cy - cr * sp * sy
    y = cr * sp * cy + sr * cp * sy
    z = cr * cp * sy - sr * sp * cy

    return torch.cat([w, x, y, z], dim=-1)

def encode_euler_angles(pitch, yaw, roll):
    return euler_to_quaternion(pitch, yaw, roll)

while True:
    batch_end = min(batch_start + batch_size, I_d_lst.shape[0])
    I_d = I_d_lst[batch_start:batch_end]
    I_d = I_d.squeeze(1)
    x_info = get_kp_info(I_d)
    # R_i = get_rotation_matrix(x_info['pitch'], x_info['yaw'], x_info['roll'])
    # Encode the Euler angles
    # encoded_angles = encode_euler_angles(x_info['pitch'], x_info['yaw'], x_info['roll'])


    # Concatenate the tensors
    concat_tensor = torch.cat([
        # encoded_angles,
        torch.cat([x_info['pitch'], x_info['yaw'], x_info['roll']], dim=1),
        x_info['t'],
        # x_info['scale'],
        x_info['exp'].reshape(batch_end - batch_start, -1),
    ], dim=1)

    all_info.append(concat_tensor)
    # all_pose.append(encoded_angles)
    all_t.append(x_info['t'])
    all_scale.append(x_info['scale'])
    all_exp.append(x_info['exp'])

    if batch_end == I_d_lst.shape[0]:
        break

    batch_start = batch_end

# Concatenate all batches
all_info_tensor = torch.cat(all_info, dim=0)
all_pose_tensor = torch.cat(all_pose, dim=0)
all_t_tensor = torch.cat(all_t, dim=0)
all_scale_tensor = torch.cat(all_scale, dim=0)
all_exp_tensor = torch.cat(all_exp, dim=0)

print(f"Shape of concatenated info tensor: {all_info_tensor.shape}")
print(f"Shape of concatenated pose tensor: {all_pose_tensor.shape}")
print(f"Shape of concatenated t tensor: {all_t_tensor.shape}")
print(f"Shape of concatenated scale tensor: {all_scale_tensor.shape}")
print(f"Shape of concatenated exp tensor: {all_exp_tensor.shape}")
print(f"Expected shape for info tensor: [{I_d_lst.shape[0]}, 69]")


In [None]:
print("Pose tensor - Mean:", all_pose_tensor.mean(), "Variance:", all_pose_tensor.var())
print("Translation tensor - Mean:", all_t_tensor.mean(), "Variance:", all_t_tensor.var())
print("Scale tensor - Mean:", all_scale_tensor.mean(), "Variance:", all_scale_tensor.var())
print("Expression tensor - Mean:", all_exp_tensor.mean(), "Variance:", all_exp_tensor.var())
print("Info tensor - Mean:", all_info_tensor.mean(), "Variance:", all_info_tensor.var())
print(all_pose_tensor.min(),
      all_pose_tensor.max())
print(all_t_tensor.min(),
      all_t_tensor.max())
print(all_scale_tensor.min(),
      all_scale_tensor.max())
print(all_exp_tensor.min(),
      all_exp_tensor.max())
print(all_info_tensor.min(),
      all_info_tensor.max())

In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

def plot_distributions(tensors, titles):
    fig, axs = plt.subplots(len(tensors), 1, figsize=(10, 5*len(tensors)))
    for tensor, title, ax in zip(tensors, titles, axs):
        ax.hist(tensor.cpu().numpy().flatten(), bins=50)
        ax.set_title(title)
    plt.tight_layout()
    plt.show()

# Plot original distributions
plot_distributions([all_pose_tensor, all_t_tensor, all_scale_tensor, all_exp_tensor],
                   ['Pose', 'Translation', 'Scale', 'Expression'])

# Define normalization functions
def z_score_normalize(tensor):
    mean = tensor.mean()
    std = tensor.std()
    return (tensor - mean) / std

def min_max_normalize(tensor, feature_range=(-1, 1)):
    min_val = tensor.min()
    max_val = tensor.max()
    scale = (feature_range[1] - feature_range[0]) / (max_val - min_val)
    return (tensor - min_val) * scale + feature_range[0]

def robust_normalize(tensor):
    median = tensor.median()
    q1 = tensor.quantile(0.25)
    q3 = tensor.quantile(0.75)
    iqr = q3 - q1
    return (tensor - median) / iqr

# Apply different normalizations
z_score_tensors = [z_score_normalize(t) for t in [all_pose_tensor, all_t_tensor, all_scale_tensor, all_exp_tensor]]
min_max_tensors = [min_max_normalize(t) for t in [all_pose_tensor, all_t_tensor, all_scale_tensor, all_exp_tensor]]
robust_tensors = [robust_normalize(t) for t in [all_pose_tensor, all_t_tensor, all_scale_tensor, all_exp_tensor]]

# Plot normalized distributions
plot_distributions(z_score_tensors, ['Z-score Pose', 'Z-score Translation', 'Z-score Scale', 'Z-score Expression'])
plot_distributions(min_max_tensors, ['Min-Max Pose', 'Min-Max Translation', 'Min-Max Scale', 'Min-Max Expression'])
plot_distributions(robust_tensors, ['Robust Pose', 'Robust Translation', 'Robust Scale', 'Robust Expression'])


### Extract pipeline

In [None]:
n_frames = I_d_lst.shape[0]
template_dct = {
    'n_frames': n_frames,
    'output_fps': 25,
    'motion': [],
    'c_d_eyes_lst': [],
    'c_d_lip_lst': [],
    'x_i_info_lst': [],
}

for i in (range(n_frames)):
    # collect s, R, δ and t for inference
    I_i = I_d_lst[i]
    x_i_info = get_kp_info(I_i)
    R_i = get_rotation_matrix(x_i_info['pitch'], x_i_info['yaw'], x_i_info['roll'])

    item_dct = {
        'scale': x_i_info['scale'].cpu().numpy().astype(np.float32),
        'R': R_i.cpu().numpy().astype(np.float32),
        'exp': x_i_info['exp'].cpu().numpy().astype(np.float32),
        't': x_i_info['t'].cpu().numpy().astype(np.float32),
    }

    template_dct['motion'].append(item_dct)

    # c_eyes = c_d_eyes_lst[i].astype(np.float32)
    # template_dct['c_d_eyes_lst'].append(c_eyes)

    # c_lip = c_d_lip_lst[i].astype(np.float32)
    # template_dct['c_d_lip_lst'].append(c_lip)

    template_dct['x_i_info_lst'].append(x_i_info)
    # print(f'frame {i} done')

#### Frontalize

In [None]:
import torch

def angular_distance(pose1, pose2):
    diff = torch.abs(pose1 - pose2)
    diff = torch.min(diff, 2*torch.pi - diff)
    return torch.norm(diff)

def find_dominant_pose(poses):
    N = poses.shape[0]
    total_distances = torch.zeros(N, device=poses.device)
    for i in range(N):
        distances = angular_distance(poses[i].unsqueeze(0), poses)
        total_distances[i] = torch.sum(distances)
    min_distance_index = torch.argmin(total_distances)
    return poses[min_distance_index], min_distance_index

# Prepare data
n_frames = len(template_dct['motion'])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Collect all poses and t values
all_poses = torch.zeros(n_frames, 3, device=device)
all_t = torch.zeros(n_frames, 3, device=device)

for i in range(n_frames):
    info = template_dct['x_i_info_lst'][i]
    roll, pitch, yaw = info['roll'], info['pitch'], info['yaw']
    all_poses[i] = torch.tensor([roll, pitch, yaw], device=device).squeeze()
    all_t[i] = torch.tensor(template_dct['motion'][i]['t'], device=device)

# Find dominant pose
dominant_pose, _ = find_dominant_pose( )

# Find median t
median_t = torch.median(all_t, dim=0).values

# Subtract dominant pose and median t from the sequence
for i in range(n_frames):
    # Update pose
    template_dct['x_i_info_lst'][i]['roll'] = (all_poses[i, 0]  - 1 * dominant_pose[0]).unsqueeze(0)
    template_dct['x_i_info_lst'][i]['pitch'] = (all_poses[i, 1] - 1 * dominant_pose[1]).unsqueeze(0)
    template_dct['x_i_info_lst'][i]['yaw'] = (all_poses[i, 2]   - 1 * dominant_pose[2]).unsqueeze(0)

    # Update t
    template_dct['motion'][i]['t'] = (all_t[i] - median_t).cpu().numpy()

    # Recalculate R with the updated pose
    new_R = get_rotation_matrix(
        template_dct['x_i_info_lst'][i]['pitch'],
        template_dct['x_i_info_lst'][i]['yaw'],
        template_dct['x_i_info_lst'][i]['roll']
    )
    template_dct['motion'][i]['R'] = new_R.cpu().numpy()

print(f"Dominant pose (roll, pitch, yaw): {dominant_pose.cpu().numpy()}")
print(f"Median t: {median_t.cpu().numpy()}")