### Compile
and initialize args

In [1]:
import time

import os
import contextlib
import os.path as osp
import numpy as np
import cv2
import torch
import yaml
import tyro
import subprocess
from rich.progress import track

from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig
from src.config.crop_config import CropConfig

def partial_fields(target_class, kwargs):
    return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})

args = ArgumentConfig()
inference_cfg = partial_fields(InferenceConfig, args.__dict__)
crop_cfg = partial_fields(CropConfig, args.__dict__)
# print("inference_cfg: ", inference_cfg)
# print("crop_cfg: ", crop_cfg)
device = 'cuda'
print("Compile complete")

Compile complete


### Initialize util functions

Initialize motion extraction pipeline

In [2]:
from src.utils.helper import load_model, concat_feat
from src.utils.camera import headpose_pred_to_degree, get_rotation_matrix
from src.utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
from src.config.inference_config import InferenceConfig
from src.utils.cropper import Cropper
from src.utils.camera import get_rotation_matrix
from src.utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream
from src.utils.crop import _transform_img, prepare_paste_back, paste_back
from src.utils.io import load_image_rgb, load_video, resize_to_limit, dump, load
from src.utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix, is_image
from src.utils.filter import smooth

In [3]:
model_config = yaml.load(open(inference_cfg.models_config, 'r'), Loader=yaml.SafeLoader)
# init F
appearance_feature_extractor = load_model(inference_cfg.checkpoint_F, model_config, device, 'appearance_feature_extractor')
# init M
motion_extractor = load_model(inference_cfg.checkpoint_M, model_config, device, 'motion_extractor')
# init W
warping_module = load_model(inference_cfg.checkpoint_W, model_config, device, 'warping_module')
# init G
spade_generator = load_model(inference_cfg.checkpoint_G, model_config, device, 'spade_generator')
# init S and R
if inference_cfg.checkpoint_S is not None and os.path.exists(inference_cfg.checkpoint_S):
    stitching_retargeting_module = load_model(inference_cfg.checkpoint_S, model_config, device, 'stitching_retargeting_module')
else:
    stitching_retargeting_module = None

cropper = Cropper(crop_cfg=crop_cfg, device=device)

In [4]:
import numpy as np


def calculate_distance_ratio(lmk: np.ndarray, idx1: int, idx2: int, idx3: int, idx4: int, eps: float = 1e-6) -> np.ndarray:
    return (np.linalg.norm(lmk[:, idx1] - lmk[:, idx2], axis=1, keepdims=True) /
            (np.linalg.norm(lmk[:, idx3] - lmk[:, idx4], axis=1, keepdims=True) + eps))


def calc_eye_close_ratio(lmk: np.ndarray, target_eye_ratio: np.ndarray = None) -> np.ndarray:
    lefteye_close_ratio = calculate_distance_ratio(lmk, 6, 18, 0, 12)
    righteye_close_ratio = calculate_distance_ratio(lmk, 30, 42, 24, 36)
    if target_eye_ratio is not None:
        return np.concatenate([lefteye_close_ratio, righteye_close_ratio, target_eye_ratio], axis=1)
    else:
        return np.concatenate([lefteye_close_ratio, righteye_close_ratio], axis=1)


def calc_lip_close_ratio(lmk: np.ndarray) -> np.ndarray:
    return calculate_distance_ratio(lmk, 90, 102, 48, 66)

def calc_ratio(lmk_lst):
    input_eye_ratio_lst = []
    input_lip_ratio_lst = []
    for lmk in lmk_lst:
        # for eyes retargeting
        input_eye_ratio_lst.append(calc_eye_close_ratio(lmk[None]))
        # for lip retargeting
        input_lip_ratio_lst.append(calc_lip_close_ratio(lmk[None]))
    return input_eye_ratio_lst, input_lip_ratio_lst

def prepare_videos(imgs, device) -> torch.Tensor:
    """ construct the input as standard
    imgs: NxBxHxWx3, uint8
    """
    if isinstance(imgs, list):
        _imgs = np.array(imgs)[..., np.newaxis]  # TxHxWx3x1
    elif isinstance(imgs, np.ndarray):
        _imgs = imgs
    else:
        raise ValueError(f'imgs type error: {type(imgs)}')

    y = _imgs.astype(np.float32) / 255.
    y = np.clip(y, 0, 1)  # clip to 0~1
    y = torch.from_numpy(y).permute(0, 4, 3, 1, 2)  # TxHxWx3x1 -> Tx1x3xHxW
    y = y.to(device)

    return y


### Single Video loading ( SLOW load)

In [21]:
vid_path = '/mnt/e/data/diffposetalk_data/TFHP_raw/crop/TH_00212/000.mp4'
is_video(vid_path)
driving_rgb_lst = load_video(vid_path)

In [22]:

# driving_rgb_lst = load_video(vid_path)
len(driving_rgb_lst), driving_rgb_lst[0].shape, type(driving_rgb_lst[0])
driving_lmk_crop_lst = cropper.calc_lmks_from_cropped_video(driving_rgb_lst)
driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst]

c_d_eyes_lst, c_d_lip_lst = calc_ratio(driving_lmk_crop_lst)
I_d_lst = prepare_videos(driving_rgb_crop_256x256_lst, device)
# I_d_lst = I_d_lst.squeeze(1)

### Folder of videos load ( fast )

In [7]:
import torchvision
import cv2
import threading
import queue
import torchvision.transforms as transforms
from concurrent.futures import ThreadPoolExecutor, as_completed
import glob
import os
import numpy as np
import time
import torch
import imageio
from tqdm import tqdm

In [8]:
def read_video_frames(video_path):
    cap = cv2.VideoCapture(video_path)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    frames = []
    for _ in range(frame_count):
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = cv2.resize(frame, (256, 256))  # Resize to 256x256
        frames.append(frame)

    cap.release()
    return video_path, frames

def read_multiple_videos(video_paths, num_threads=16):
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        results = list(executor.map(read_video_frames, video_paths))
    return results

def get_video_paths(directory, max_videos=100):
    video_paths = []
    for root, dirs, files in tqdm(os.walk(directory), desc="Searching for videos"):
        for file in files:
            if file.endswith(('.mp4', '.avi', '.mov')):  # Add more video extensions if needed
                video_paths.append(os.path.join(root, file))
                if len(video_paths) >= max_videos:
                    return sorted(video_paths)
    return sorted(video_paths)

# Usage
# video_dir = '/mnt/e/data/vox2/0_500_512_video/id00062/ImB2zCgOuyk
video_dir = '/mnt/e/data/vox2/videos/512'
video_paths = get_video_paths(video_dir, max_videos=5)  # Limit to 100 videos

print(f"Found {len(video_paths)} video files.")


Searching for videos: 2it [00:00, 130.59it/s]

Found 5 video files.





In [9]:

video_frames = read_multiple_videos(video_paths)
all_frames = []
total_frames = 0
video_lengths = []

for video_path, frames in video_frames:
    all_frames.extend(frames)
    frame_count = len(frames)
    total_frames += frame_count
    video_lengths.append(frame_count)
    print(f"Processed video: {video_path}, frames: {frame_count}")

print(f"\nTotal frames across all videos: {total_frames}")
print(f"Video lengths: {video_lengths}")

# Convert to numpy array
all_frames = np.array(all_frames)

print(f"Shape of concatenated array: {all_frames.shape}")

Processed video: /mnt/e/data/vox2/videos/512/id00012/aE4Om0EEiuk/00116.mp4, frames: 187
Processed video: /mnt/e/data/vox2/videos/512/id00012/aE4Om0EEiuk/00117.mp4, frames: 172
Processed video: /mnt/e/data/vox2/videos/512/id00012/aE4Om0EEiuk/00120.mp4, frames: 412
Processed video: /mnt/e/data/vox2/videos/512/id00012/aE4Om0EEiuk/00122.mp4, frames: 300
Processed video: /mnt/e/data/vox2/videos/512/id00012/aE4Om0EEiuk/00123.mp4, frames: 252

Total frames across all videos: 1323
Video lengths: [187, 172, 412, 300, 252]
Shape of concatenated array: (1323, 256, 256, 3)


In [9]:
# import cv2
# import time

# # Create a window
# cv2.namedWindow('Video Frames', cv2.WINDOW_NORMAL)

# for i in range(208):
#     # Get the current frame
#     frame = all_frames[i]

#     # Display the frame number
#     cv2.putText(frame, f'Frame: {i+1}', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

#     # Display the image
#     cv2.imshow('Video Frames', cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

#     # Wait for a key press or 33 ms (approx. 30 fps)
#     key = cv2.waitKey(33) & 0xFF

#     # If 'q' is pressed, break the loop
#     if key == ord('q'):
#         break

# # Close all OpenCV windows
# cv2.destroyAllWindows()


In [8]:
def prepare_videos_(imgs, device):
    """ construct the input as standard
    imgs: NxHxWx3, uint8
    """
    if isinstance(imgs, list):
        _imgs = np.array(imgs)
    elif isinstance(imgs, np.ndarray):
        _imgs = imgs
    else:
        raise ValueError(f'imgs type error: {type(imgs)}')

    # y = _imgs.astype(np.float32) / 255.
    y = _imgs
    y = torch.from_numpy(y).permute(0, 3, 1, 2)  # NxHxWx3 -> Nx3xHxW
    y = y.to(device)
    y = y / 255.
    y = torch.clamp(y, 0, 1)

    return y

In [None]:
driving_rgb_lst = I_d_lst
I_d_lst = prepare_videos_(driving_rgb_lst, device)
# I_d_lst = I_d_lst.unsqueeze(1)

print(f"Shape of driving video: {I_d_lst.shape}")

### Demo pipeline

Motion Extractor

In [13]:
def get_kp_info(x: torch.Tensor, **kwargs) -> dict:
    """ get the implicit keypoint information
    x: Bx3xHxW, normalized to 0~1
    flag_refine_info: whether to trandform the pose to degrees and the dimention of the reshape
    return: A dict contains keys: 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'
    """
    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16,
                                 enabled=inference_cfg.flag_use_half_precision):
        kp_info = motion_extractor(x)

        if inference_cfg.flag_use_half_precision:
            # float the dict
            for k, v in kp_info.items():
                if isinstance(v, torch.Tensor):
                    kp_info[k] = v.float()

    flag_refine_info: bool = kwargs.get('flag_refine_info', True)
    if flag_refine_info:
        bs = kp_info['kp'].shape[0]
        kp_info['pitch'] = headpose_pred_to_degree(kp_info['pitch'])[:, None]  # Bx1
        kp_info['yaw'] = headpose_pred_to_degree(kp_info['yaw'])[:, None]  # Bx1
        kp_info['roll'] = headpose_pred_to_degree(kp_info['roll'])[:, None]  # Bx1
        # kp_info['kp'] = kp_info['kp'].reshape(bs, -1, 3)  # BxNx3
        # kp_info['exp'] = kp_info['exp'].reshape(bs, -1, 3)  # BxNx3

    return kp_info

def transform_keypoint(kp_info: dict):
    """
    transform the implicit keypoints with the pose, shift, and expression deformation
    kp: BxNx3
    """
    kp = kp_info['kp']    # (bs, k, 3)
    pitch, yaw, roll = kp_info['pitch'], kp_info['yaw'], kp_info['roll']

    t, exp = kp_info['t'], kp_info['exp']
    scale = kp_info['scale']

    pitch = headpose_pred_to_degree(pitch)
    yaw = headpose_pred_to_degree(yaw)
    roll = headpose_pred_to_degree(roll)

    bs = kp.shape[0]
    if kp.ndim == 2:
        num_kp = kp.shape[1] // 3  # Bx(num_kpx3)
    else:
        num_kp = kp.shape[1]  # Bxnum_kpx3

    rot_mat = get_rotation_matrix(pitch, yaw, roll)    # (bs, 3, 3)

    # Eqn.2: s * (R * x_c,s + exp) + t
    kp_transformed = kp.view(bs, num_kp, 3) @ rot_mat + exp.view(bs, num_kp, 3)
    kp_transformed *= scale[..., None]  # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3)
    kp_transformed[:, :, 0:2] += t[:, None, 0:2]  # remove z, only apply tx ty

    return kp_transformed

In [23]:
print(I_d_lst.shape)
if len(I_d_lst.shape) == 5:
    I_d_lst = I_d_lst.squeeze(1)
    print(I_d_lst.shape)

torch.Size([1003, 1, 3, 256, 256])
torch.Size([1003, 3, 256, 256])


In [25]:
batch_size = 512
total_frames = I_d_lst.shape[0]
x_d_list = []
x_i_info_list = []

for i in range(0, total_frames, batch_size):
    batch = I_d_lst[i:i+batch_size]
    x_i_info = get_kp_info(batch)

    concat_tensor = torch.cat([
        x_i_info['kp'], # 63
        x_i_info['exp'], # 63, .reshape(mini_batch_end - mini_batch_start, -1),
        x_i_info['t'], # 3
        x_i_info['pitch'], # 1
        x_i_info['yaw'], # 1
        x_i_info['roll'], # 1
        x_i_info['scale'], # 1
    ], dim=1)

    x_i_info_list.append(concat_tensor)
x_i_info_list = torch.cat(x_i_info_list, dim=0)

print(f'x_d_list: {len(x_i_info_list)}', x_i_info_list.shape)


x_d_list: 1003 torch.Size([1003, 133])


In [None]:
# Convert x_i_info_list to numpy array
x_d_np = x_i_info_list.cpu().numpy()

# Define the path to save the numpy array
save_path = 'ls /mnt/c/Users/mjh/Downloads/x_d_list.npy'  # Replace with your desired path

# Save the numpy array
np.save(save_path, x_d_np)

print(f"Saved x_d_list to {save_path}")
print(f"Shape of saved array: {x_d_np.shape}")


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Flatten the last two dimensions of x_d
x_d_flattened = x_d.reshape(x_d.shape[0], -1)

# Get the number of parameters
num_params = x_d_flattened.shape[1]

# Calculate grid dimensions
grid_rows = int(np.ceil(np.sqrt(num_params)))
grid_cols = int(np.ceil(num_params / grid_rows))

plt.figure(figsize=(20, 15))

for i in range(num_params):
    # Extract the current parameter
    param_values = x_d_flattened[:, i].cpu().numpy()

    # Create a subplot for each parameter
    plt.subplot(grid_rows, grid_cols, i + 1)

    # Plot the histogram
    plt.hist(param_values, bins=50, edgecolor='black', range=(-1, 1))

    # Set title and labels
    plt.title(f'Parameter {i + 1}')
    plt.xlabel('Value')
    plt.ylabel('Frequency')

    # Set fixed x-axis range
    plt.xlim(-1, 1)

    # Add some statistics to the plot
    plt.text(0.05, 0.95, f'Mean: {np.mean(param_values):.4f}\nStd: {np.std(param_values):.4f}',
             transform=plt.gca().transAxes, verticalalignment='top', fontsize=8)

# Adjust the layout and display the plot
plt.suptitle('Distribution of Parameters Across Batches', fontsize=16)
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()

# Print summary statistics
print("Summary Statistics:")
for i in range(num_params):
    param_values = x_d_flattened[:, i].cpu().numpy()
    print(f"Parameter {i+1}:")
    print(f"  Mean: {np.mean(param_values):.4f}")
    print(f"  Std Dev: {np.std(param_values):.4f}")
    print(f"  Min: {np.min(param_values):.4f}")
    print(f"  Max: {np.max(param_values):.4f}")
    print()


In [None]:
# bz = 250
# total_count = I_d_lst.shape[0]
# for i in range(0, total_count, bz):
#     I_d_part = I_d_lst[i:i+bz]
#     x_i_info = get_kp_info(I_d_part)
#     print(x_i_info['kp'].shape)

In [None]:
# x_i_info = get_kp_info(I_d_part)

In [None]:
I_d_lst[0]
x_i_info_0 = get_kp_info(I_d_lst[0])

x_i_info_0['latent'].shape
x_i_info_0_latent_np = x_i_info_0['latent'].squeeze(0).cpu().numpy()
x_i_info_0_latent_np.shape


In [None]:
n_frames = I_d_lst.shape[0]
template_dct = {
    'n_frames': n_frames,
    'output_fps': 25,
    'motion': [],
    'c_d_eyes_lst': [],
    'c_d_lip_lst': [],
    'x_i_info_lst': [],
    'latent_lst': []  # New list to store latent vectors
}

for i in range(n_frames):
    # collect s, R, δ and t for inference
    I_i = I_d_lst[i]
    x_i_info = get_kp_info(I_i)
    R_i = get_rotation_matrix(x_i_info['pitch'], x_i_info['yaw'], x_i_info['roll'])

    item_dct = {
        'scale': x_i_info['scale'].cpu().numpy().astype(np.float32),
        'R': R_i.cpu().numpy().astype(np.float32),
        'exp': x_i_info['exp'].cpu().numpy().astype(np.float32),
        't': x_i_info['t'].cpu().numpy().astype(np.float32),
    }

    template_dct['motion'].append(item_dct)
    template_dct['x_i_info_lst'].append(x_i_info)

    # Extract and store the latent vector
    latent = x_i_info['latent'].squeeze(0).cpu().numpy()
    template_dct['latent_lst'].append(latent)

    print(f'frame {i} done')

# After the loop, you can verify the shape of the stored latent vectors
print(f"Number of latent vectors: {len(template_dct['latent_lst'])}")
print(f"Shape of each latent vector: {template_dct['latent_lst'][0].shape}")

In [None]:
import pandas as pd
lat_tensor = torch.tensor(template_dct['latent_lst'])

# Analyze latent vectors
latents_analysis = lat_tensor[:, :]  # Analyze first 100 dimensions
latent_dim = latents_analysis.shape[-1]

# Flatten the batch dimension
flattened_motion = latents_analysis.view(-1, latent_dim)

# Calculate basic statistics
mean = flattened_motion.mean(dim=0)
std = flattened_motion.std(dim=0)
min_vals = flattened_motion.min(dim=0).values
max_vals = flattened_motion.max(dim=0).values

# Calculate absolute magnitude statistics
abs_flattened_motion = flattened_motion.abs()
abs_mean = abs_flattened_motion.mean(dim=0)
abs_std = abs_flattened_motion.std(dim=0)
abs_min_vals = abs_flattened_motion.min(dim=0).values
abs_max_vals = abs_flattened_motion.max(dim=0).values

# Create a summary dataframe
summary = pd.DataFrame({
    'Mean': mean.numpy(),
    'Std': std.numpy(),
    'Min': min_vals.numpy(),
    'Max': max_vals.numpy(),
    'Abs Mean': abs_mean.numpy(),
    'Abs Std': abs_std.numpy(),
    'Abs Min': abs_min_vals.numpy(),
    'Abs Max': abs_max_vals.numpy()
})

# Set display options to show all rows and columns
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
pd.set_option('display.max_colwidth', None)

# Print the full summary
print(summary.to_string())

In [None]:
import numpy as np
import torch

# Assuming abs_std is already defined
# abs_std = torch.randn(70)  # Example tensor, replace with your actual tensor

# Define the bins in logarithmic scale from 1e-1 to 1e-8
bins = np.logspace(-1, -8, num=8)

# Digitize the abs_std values into the defined bins
bin_indices = np.digitize(abs_mean.numpy(), bins)

# Create a list to store indices for each bin
bin_index_lists = [[] for _ in range(len(bins) + 1)]

# Populate the bin_index_lists
for idx, bin_idx in enumerate(bin_indices):
    bin_index_lists[bin_idx].append(idx)

# Print the bin counts and indices
bin_counts = np.bincount(bin_indices, minlength=len(bins) + 1)
for i, (count, indices) in enumerate(zip(bin_counts, bin_index_lists)):
    if i == 0:
        print(f"< {bins[0]:.1e}: {count}")
    elif i == len(bins):
        print(f">= {bins[-1]:.1e}: {count}")
    else:
        print(f"{bins[i-1]:.1e} - {bins[i]:.1e}: {count}")
    print(f"Indices: {indices}")
    print()

In [None]:
latent_dim = 100

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# # Plot histograms for each feature
# fig, axes = plt.subplots( 10, latent_dim // 10 + int(latent_dim % 10!=0), figsize=(20, 14))
# axes = axes.flatten()

# from tqdm import tqdm
# for i in tqdm(range(latent_dim), desc="Plotting histograms"):
#     sns.histplot(flattened_motion[:, i].numpy(), ax=axes[i], kde=True)
#     axes[i].set_title(f'Feature {i}')
#     axes[i].set_xlabel('')
#     axes[i].set_ylabel('')

# plt.tight_layout()
# plt.show()

# Plot box plots for each feature
plt.figure(figsize=(20, 6))
sns.boxplot(data=flattened_motion[:, :latent_dim].numpy())
plt.title('Box Plot of Motion Features')
plt.xlabel('Feature Index')
plt.ylabel('Value')
plt.show()

In [None]:
n_frames = I_d_lst.shape[0]
for i in (range(n_frames)):
    # collect s, R, δ and t for inference
    I_i = I_d_lst[i]
    x_i_info = get_kp_info(I_i)

In [None]:
n_frames = I_d_lst.shape[0]
template_dct = {
    'n_frames': n_frames,
    'output_fps': 25,
    'motion': [],
    'c_d_eyes_lst': [],
    'c_d_lip_lst': [],
    'x_i_info_lst': [],
}

for i in (range(n_frames)):
    # collect s, R, δ and t for inference
    I_i = I_d_lst[i]
    x_i_info = get_kp_info(I_i)
    R_i = get_rotation_matrix(x_i_info['pitch'], x_i_info['yaw'], x_i_info['roll'])

    item_dct = {
        'scale': x_i_info['scale'].cpu().numpy().astype(np.float32),
        'R': R_i.cpu().numpy().astype(np.float32),
        'exp': x_i_info['exp'].cpu().numpy().astype(np.float32),
        't': x_i_info['t'].cpu().numpy().astype(np.float32),
    }

    template_dct['motion'].append(item_dct)

    # c_eyes = c_d_eyes_lst[i].astype(np.float32)
    # template_dct['c_d_eyes_lst'].append(c_eyes)

    # c_lip = c_d_lip_lst[i].astype(np.float32)
    # template_dct['c_d_lip_lst'].append(c_lip)

    template_dct['x_i_info_lst'].append(x_i_info)
    print(f'frame {i} done')

#### Frontalize


In [None]:
# R = template_dct['motion'][0]['R']
# exp = template_dct['motion'][0]['exp']
# t = template_dct['motion'][0]['t']
# scale = template_dct['motion'][0]['scale']
# # print dims
# print(R.shape, exp.shape, t.shape, scale.shape)
# # print flatten dims
# print(R.flatten().shape, exp.flatten().shape, t.flatten().shape, scale.flatten().shape)
# # print range
# print(R.min(), R.max(), exp.min(), exp.max(), t.min(), t.max(), scale.min(), scale.max())

In [None]:
# for i in range(n_frames):
#     R = template_dct['motion'][i]['R']
#     exp = template_dct['motion'][i]['exp']
#     t = template_dct['motion'][i]['t']
#     scale = template_dct['motion'][i]['scale']
#     info = template_dct['x_i_info_lst'][i]
#     roll, pitch, yaw = info['roll'], info['pitch'], info['yaw']

#     new_R = get_rotation_matrix(pitch, yaw, roll)

In [None]:
import torch

def angular_distance(pose1, pose2):
    diff = torch.abs(pose1 - pose2)
    diff = torch.min(diff, 2*torch.pi - diff)
    return torch.norm(diff)

def find_dominant_pose(poses):
    N = poses.shape[0]
    total_distances = torch.zeros(N, device=poses.device)
    for i in range(N):
        distances = angular_distance(poses[i].unsqueeze(0), poses)
        total_distances[i] = torch.sum(distances)
    min_distance_index = torch.argmin(total_distances)
    return poses[min_distance_index], min_distance_index

# Prepare data
n_frames = len(template_dct['motion'])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Collect all poses and t values
all_poses = torch.zeros(n_frames, 3, device=device)
all_t = torch.zeros(n_frames, 3, device=device)

for i in range(n_frames):
    info = template_dct['x_i_info_lst'][i]
    roll, pitch, yaw = info['roll'], info['pitch'], info['yaw']
    all_poses[i] = torch.tensor([roll, pitch, yaw], device=device).squeeze()
    all_t[i] = torch.tensor(template_dct['motion'][i]['t'], device=device)

# Find dominant pose
dominant_pose, _ = find_dominant_pose(all_poses)

# Find median t
median_t = torch.median(all_t, dim=0).values

# Subtract dominant pose and median t from the sequence
for i in range(n_frames):
    # Update pose
    template_dct['x_i_info_lst'][i]['roll'] = (all_poses[i, 0]  - 1 * dominant_pose[0]).unsqueeze(0)
    template_dct['x_i_info_lst'][i]['pitch'] = (all_poses[i, 1] - 1 * dominant_pose[1]).unsqueeze(0)
    template_dct['x_i_info_lst'][i]['yaw'] = (all_poses[i, 2]   - 1 * dominant_pose[2]).unsqueeze(0)

    # Update t
    template_dct['motion'][i]['t'] = (all_t[i] - median_t).cpu().numpy()

    # Recalculate R with the updated pose
    new_R = get_rotation_matrix(
        template_dct['x_i_info_lst'][i]['pitch'],
        template_dct['x_i_info_lst'][i]['yaw'],
        template_dct['x_i_info_lst'][i]['roll']
    )
    template_dct['motion'][i]['R'] = new_R.cpu().numpy()

print(f"Dominant pose (roll, pitch, yaw): {dominant_pose.cpu().numpy()}")
print(f"Median t: {median_t.cpu().numpy()}")

### Generator Speed test

Initialize useful functions

In [None]:
def prepare_source(img: np.ndarray) -> torch.Tensor:
    """ construct the input as standard
    img: HxWx3, uint8, 256x256
    """
    h, w = img.shape[:2]
    x = img.copy()

    if x.ndim == 3:
        x = x[np.newaxis].astype(np.float32) / 255.  # HxWx3 -> 1xHxWx3, normalized to 0~1
    elif x.ndim == 4:
        x = x.astype(np.float32) / 255.  # BxHxWx3, normalized to 0~1
    else:
        raise ValueError(f'img ndim should be 3 or 4: {x.ndim}')
    x = np.clip(x, 0, 1)  # clip to 0~1
    x = torch.from_numpy(x).permute(0, 3, 1, 2)  # 1xHxWx3 -> 1x3xHxW
    x = x.to(device)
    return x

def warp_decode(feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
    """ get the image after the warping of the implicit keypoints
    feature_3d: Bx32x16x64x64, feature volume
    kp_source: BxNx3
    kp_driving: BxNx3
    """
    # The line 18 in Algorithm 1: D(W(f_s; x_s, x′_d,i)）
    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16,
                                 enabled=inference_cfg.flag_use_half_precision):
        # get decoder input
        ret_dct = warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
        # decode
        ret_dct['out'] = spade_generator(feature=ret_dct['out'])

    return ret_dct

def extract_feature_3d( x: torch.Tensor) -> torch.Tensor:
    """ get the appearance feature of the image by F
    x: Bx3xHxW, normalized to 0~1
    """
    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16,
                                 enabled=inference_cfg.flag_use_half_precision):
        feature_3d = appearance_feature_extractor(x)

    return feature_3d.float()



def parse_output(out: torch.Tensor) -> np.ndarray:
    """ construct the output as standard
    return: 1xHxWx3, uint8
    """
    out = np.transpose(out.data.cpu().numpy(), [0, 2, 3, 1])  # 1x3xHxW -> 1xHxWx3
    out = np.clip(out, 0, 1)  # clip to 0~1
    out = np.clip(out * 255, 0, 255).astype(np.uint8)  # 0~1 -> 0~255

    return out

Initialize source image

In [None]:
input_path = '/mnt/c/Users/mjh/Downloads/live_in/t4.jpg'
img_rgb = load_image_rgb(input_path)
source_rgb_lst = [img_rgb]

source_lmk = cropper.calc_lmk_from_cropped_image(source_rgb_lst[0])
img_crop_256x256 = cv2.resize(source_rgb_lst[0], (256, 256))  # force to resize to 256x256

I_s = prepare_source(img_crop_256x256)
x_s_info = get_kp_info(I_s)
x_c_s = x_s_info['kp']
x_s = transform_keypoint(x_s_info)
f_s = extract_feature_3d(I_s)


Single frame retarget

In [None]:
# R = template_dct['motion'][0]['R']
# exp = template_dct['motion'][0]['exp']
# t = template_dct['motion'][0]['t']
# scale = template_dct['motion'][0]['scale']

# scale_tensor = torch.tensor(scale, device=device)
# R_tensor = torch.tensor(R, device=device)
# exp_tensor = torch.tensor(exp, device=device)
# t_tensor = torch.tensor(t, device=device)
# print(scale_tensor.shape, R_tensor.shape, exp_tensor.shape, t_tensor.shape)

# start = time.time()
# x_d_i_new = scale_tensor * (x_c_s @ R_tensor + exp_tensor) + t_tensor

# # x_d_i_new = scale * (x_c_s @ R + exp) + t
# out = warp_decode(f_s, x_s, x_d_i_new)
# # print(out)
# # I_p_i = parse_output(out['out'])[0]
# end_time = time.time() - start
# print(f'warp_decode time: {end_time}')

Large chunk of frames generator. Performance testing

In [None]:
import cv2
import time

# Initialize variables
frame_index = 0
total_frames = len(template_dct['motion'])

# Create a window for display
# cv2.namedWindow('Processed Frame', cv2.WINDOW_NORMAL)
# cv2.resizeWindow('Processed Frame', 512, 512)  # Adjust size as needed

In [None]:
import cv2
import time

# Initialize variables
frame_index = 0
total_frames = len(template_dct['motion'])

# Create a window for display
cv2.namedWindow('Processed Frame', cv2.WINDOW_NORMAL)
cv2.resizeWindow('Processed Frame', 512, 512)  # Adjust size as needed

while frame_index < total_frames:
    # Get motion data for the current frame
    R = template_dct['motion'][frame_index]['R']
    exp = template_dct['motion'][frame_index]['exp']
    t = template_dct['motion'][frame_index]['t']
    scale = template_dct['motion'][frame_index]['scale']

    # Convert to tensors
    scale_tensor = torch.tensor(scale, device=device)
    R_tensor = torch.tensor(R, device=device)
    exp_tensor = torch.tensor(exp, device=device)
    t_tensor = torch.tensor(t, device=device)

    # Process the frame
    x_d_i_new = scale_tensor * (x_c_s @ R_tensor + exp_tensor) + t_tensor
    out = warp_decode(f_s, x_s, x_d_i_new)

    # Convert tensor to numpy array and rescale to 0-255 range
    img_np = (out['out'][0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)

    # Convert from RGB to BGR for cv2
    img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)

    # Display the frame
    cv2.imshow('Processed Frame', img_bgr)

    # Print progress
    print(f"Processed frame {frame_index+1}/{total_frames}")

    # Wait for a short time and check for 'q' key to quit
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

    frame_index += 1

    # Optional: add a small delay to make the display more visible
    time.sleep(0.01)  # Adjust as needed

# Clean up
cv2.destroyAllWindows()

print("Processing complete.")