In [1]:
import time

import os
import contextlib
import os.path as osp
import numpy as np
import cv2
import torch
import yaml
import tyro
import subprocess
from rich.progress import track

from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig
from src.config.crop_config import CropConfig

from src.utils.helper import load_model, concat_feat
from src.utils.camera import headpose_pred_to_degree, get_rotation_matrix
from src.utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
from src.config.inference_config import InferenceConfig
from src.utils.cropper import Cropper
from src.utils.camera import get_rotation_matrix
from src.utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream
from src.utils.crop import _transform_img, prepare_paste_back, paste_back
from src.utils.io import load_image_rgb, load_video, resize_to_limit, dump, load
from src.utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix, is_image
from src.utils.filter import smooth

def partial_fields(target_class, kwargs):
    """Instantiate ``target_class`` from the subset of ``kwargs`` it recognises.

    A key/value pair is forwarded to the constructor only when
    ``hasattr(target_class, key)`` is true; every other entry is dropped.
    """
    accepted = {}
    for key, value in kwargs.items():
        if hasattr(target_class, key):
            accepted[key] = value
    return target_class(**accepted)

# Build the inference and crop configurations from the default argument values:
# each config only receives the fields it declares (see partial_fields).
args = ArgumentConfig()
inference_cfg = partial_fields(InferenceConfig, args.__dict__)
crop_cfg = partial_fields(CropConfig, args.__dict__)
# print("inference_cfg: ", inference_cfg)
# print("crop_cfg: ", crop_cfg)
device = 'cuda'
print("Compile complete")

# Load model
# Read the model description inside a `with` block so the file handle is
# closed as soon as parsing finishes instead of being left to the GC.
with open(inference_cfg.models_config, 'r') as models_config_file:
    model_config = yaml.load(models_config_file, Loader=yaml.SafeLoader)
# init F
appearance_feature_extractor = load_model(inference_cfg.checkpoint_F, model_config, device, 'appearance_feature_extractor')
# init M
motion_extractor = load_model(inference_cfg.checkpoint_M, model_config, device, 'motion_extractor')
# init W
warping_module = load_model(inference_cfg.checkpoint_W, model_config, device, 'warping_module')
# init G
spade_generator = load_model(inference_cfg.checkpoint_G, model_config, device, 'spade_generator')
# init S and R
# The stitching/retargeting checkpoint is optional: fall back to None when it
# is not configured or the file does not exist.
if inference_cfg.checkpoint_S is not None and os.path.exists(inference_cfg.checkpoint_S):
    stitching_retargeting_module = load_model(inference_cfg.checkpoint_S, model_config, device, 'stitching_retargeting_module')
else:
    stitching_retargeting_module = None

cropper = Cropper(crop_cfg=crop_cfg, device=device)

Compile complete


In [12]:
import numpy as np


def calculate_distance_ratio(lmk: np.ndarray, idx1: int, idx2: int, idx3: int, idx4: int, eps: float = 1e-6) -> np.ndarray:
    """Ratio of two landmark-to-landmark distances, computed per batch entry.

    lmk: landmarks indexed as lmk[:, idx]; the norm is taken over axis 1 of the
        per-landmark difference, with keepdims=True
    idx1, idx2: landmark pair whose distance forms the numerator
    idx3, idx4: landmark pair whose distance forms the denominator
    eps: added to the denominator to avoid division by zero
    return: |lmk[idx1] - lmk[idx2]| / (|lmk[idx3] - lmk[idx4]| + eps)
    """
    numerator = np.linalg.norm(lmk[:, idx1] - lmk[:, idx2], axis=1, keepdims=True)
    denominator = np.linalg.norm(lmk[:, idx3] - lmk[:, idx4], axis=1, keepdims=True)
    return numerator / (denominator + eps)


def calc_eye_close_ratio(lmk: np.ndarray, target_eye_ratio: np.ndarray = None) -> np.ndarray:
    """Eye distance ratios for both eyes, concatenated column-wise.

    lmk: landmark array passed straight to calculate_distance_ratio
    target_eye_ratio: optional extra column(s) appended after the two eye ratios
    return: [left, right] or [left, right, target_eye_ratio] joined along axis 1
    """
    columns = [
        calculate_distance_ratio(lmk, 6, 18, 0, 12),    # left eye
        calculate_distance_ratio(lmk, 30, 42, 24, 36),  # right eye
    ]
    if target_eye_ratio is not None:
        columns.append(target_eye_ratio)
    return np.concatenate(columns, axis=1)


def calc_lip_close_ratio(lmk: np.ndarray) -> np.ndarray:
    """Lip distance ratio: distance between landmarks 90/102 over distance between 48/66."""
    lip_ratio = calculate_distance_ratio(lmk, idx1=90, idx2=102, idx3=48, idx4=66)
    return lip_ratio

def calc_ratio(lmk_lst):
    """Compute the eye and lip ratios for every landmark set in ``lmk_lst``.

    Each entry gets a leading batch axis (``lmk[None]``) before being handed to
    calc_eye_close_ratio (eyes retargeting) and calc_lip_close_ratio (lip retargeting).
    return: (list of eye ratios, list of lip ratios), both in input order
    """
    # Single pass over the input so one-shot iterables behave as before.
    ratio_pairs = [
        (calc_eye_close_ratio(lmk[None]), calc_lip_close_ratio(lmk[None]))
        for lmk in lmk_lst
    ]
    eye_ratios = [eye for eye, _lip in ratio_pairs]
    lip_ratios = [lip for _eye, lip in ratio_pairs]
    return eye_ratios, lip_ratios

def prepare_videos(imgs, device) -> torch.Tensor:
    """ construct the driving-video input as standard
    imgs: either a list of frames (stacked and given a trailing axis, e.g.
        TxHxWx3 -> TxHxWx3x1) or an np.ndarray that is used as-is and must
        already be 5-D; values are divided by 255
    device: target device for the returned tensor
    return: float32 tensor with axes permuted as (0, 4, 3, 1, 2), i.e. Tx1x3xHxW, in 0~1
    raises: ValueError if imgs is neither a list nor an np.ndarray
    """
    if not isinstance(imgs, (list, np.ndarray)):
        raise ValueError(f'imgs type error: {type(imgs)}')

    # A list is stacked into one array and gets the trailing singleton axis.
    frames = np.array(imgs)[..., np.newaxis] if isinstance(imgs, list) else imgs

    normalised = np.clip(frames.astype(np.float32) / 255., 0, 1)  # clip to 0~1
    # TxHxWx3x1 -> Tx1x3xHxW
    return torch.from_numpy(normalised).permute(0, 4, 3, 1, 2).to(device)


In [13]:
import torchvision
import cv2
import threading
import queue
import torchvision.transforms as transforms
from concurrent.futures import ThreadPoolExecutor, as_completed
import glob
import os
import numpy as np
import time
import torch
import imageio

In [14]:
def read_video_frames(video_path, frame_queue, num_threads=4):
    """Decode one video with several threads and push its frames onto ``frame_queue``.

    The frame count reported by OpenCV is split into ``num_threads`` contiguous
    chunks (the last chunk takes the remainder). Each worker opens its own
    ``cv2.VideoCapture``, seeks to the start of its chunk and reads forward.

    video_path: path handed to ``cv2.VideoCapture``
    frame_queue: queue receiving ``(video_path, frame_index, rgb_frame)`` tuples
    num_threads: number of worker threads / chunks

    Because several workers write to the same queue, items arrive in no
    guaranteed order; consumers that need the original order must sort on
    ``frame_index``. Exceptions raised inside a worker are re-raised here.
    """
    # Probe the frame count and release the handle right away; the workers
    # open their own captures.
    probe = cv2.VideoCapture(video_path)
    try:
        total_frames = int(probe.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        probe.release()
    frames_per_thread = total_frames // num_threads

    def worker(start_frame, end_frame):
        cap = cv2.VideoCapture(video_path)
        try:
            cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
            for frame_index in range(start_frame, end_frame):
                ret, frame = cap.read()
                if not ret:
                    # Stop at the first failed read (e.g. fewer frames than reported).
                    break
                frame_queue.put((video_path, frame_index, cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        finally:
            # Release the capture even if decoding or colour conversion raises.
            cap.release()

    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        futures = []
        for i in range(num_threads):
            start = i * frames_per_thread
            end = start + frames_per_thread if i < num_threads - 1 else total_frames
            if start < end:
                # Skip empty chunks (fewer frames than threads) instead of
                # opening a capture that would read nothing.
                futures.append(executor.submit(worker, start, end))

        for future in as_completed(futures):
            future.result()

def read_multiple_videos(video_paths, num_threads_per_video=4):
    """Read several videos concurrently into one shared queue.

    video_paths: sequence of video paths; one reader task is started per path
    num_threads_per_video: forwarded to read_video_frames for each video
    return: a ``queue.Queue`` holding every ``(video_path, frame_index, frame)``
        tuple produced by read_video_frames, followed by a single ``None``
        sentinel marking the end of the stream

    All reading has finished by the time this function returns. An empty
    ``video_paths`` yields a queue that contains only the sentinel
    (``ThreadPoolExecutor`` rejects ``max_workers=0``, so no pool is created).
    """
    frame_queue = queue.Queue()

    if len(video_paths) > 0:
        with ThreadPoolExecutor(max_workers=len(video_paths)) as executor:
            futures = [
                executor.submit(read_video_frames, path, frame_queue, num_threads_per_video)
                for path in video_paths
            ]
            for future in as_completed(futures):
                # Re-raise any exception from a reader task.
                future.result()

    frame_queue.put(None)  # Signal that all videos have been read
    return frame_queue

time_start = time.time()

# Get all .mp4 files in the specified directory
video_dir = '/mnt/e/data/vox2/0_500_512_video/id00062/ImB2zCgOuyk'
# Sorted so the video order does not depend on filesystem enumeration order.
video_paths = sorted(glob.glob(os.path.join(video_dir, '*.mp4')))

print(f"Found {len(video_paths)} video files.")
if not video_paths:
    # Fail early with a clear message instead of an opaque error further down.
    raise FileNotFoundError(f"No .mp4 files found in {video_dir}")

frame_queue = read_multiple_videos(video_paths)

# Process frames as they become available
indexed_frames = []
total_frames = 0

while True:
    item = frame_queue.get()
    if item is None:
        break
    indexed_frames.append(item)
    total_frames += 1

    if total_frames % 1000 == 0:
        print(f"Processed {total_frames} frames so far...")

# The reader threads enqueue frames in arrival order, which is nondeterministic
# across videos and chunks. Restore (video, frame index) order so the driving
# sequence is temporally coherent and reproducible between runs.
indexed_frames.sort(key=lambda entry: (entry[0], entry[1]))
frames_list = [frame for video_path, frame_index, frame in indexed_frames]

print("\nProcessing complete. Concatenating frames...")
process_complete_time = time.time()
print(f"Time taken to process all frames: {process_complete_time - time_start:.2f} seconds")

# Concatenate all frames into one large numpy array
all_frames = np.stack(frames_list, axis=0)

# Convert to tensor
# all_frames_tensor = torch.from_numpy(all_frames).permute(0, 3, 1, 2).float() / 255.0
# transform_to_tensor_time = time.time()
# print(f"Time taken to convert frames to tensor: {transform_to_tensor_time - process_complete_time:.2f} seconds")

# print(f"Total frames across all videos: {total_frames}")
# print(f"Shape of concatenated tensor: {all_frames_tensor.shape}")
# print(f"Dtype of tensor: {all_frames_tensor.dtype}")
# print(f"Memory usage of tensor: {all_frames_tensor.element_size() * all_frames_tensor.nelement() / (1024**3):.2f} GB")

Found 17 video files.
Processed 1000 frames so far...
Processed 2000 frames so far...
Processed 3000 frames so far...

Processing complete. Concatenating frames...
Time taken to process all frames: 1.46 seconds


In [6]:
def get_kp_info(x: torch.Tensor, **kwargs) -> dict:
    """ get the implicit keypoint information from the motion extractor
    x: Bx3xHxW, normalized to 0~1
    flag_refine_info (keyword, default True): whether to transform the pose to
        degrees and reshape 'kp' / 'exp' to BxNx3
    return: the dict produced by motion_extractor; expected to contain the keys
        'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'
    """
    use_half = inference_cfg.flag_use_half_precision
    with torch.no_grad():
        with torch.autocast(device_type=device, dtype=torch.float16, enabled=use_half):
            kp_info = motion_extractor(x)

            if use_half:
                # cast every tensor entry back to float32
                for key in list(kp_info.keys()):
                    value = kp_info[key]
                    if isinstance(value, torch.Tensor):
                        kp_info[key] = value.float()

    if not kwargs.get('flag_refine_info', True):
        return kp_info

    bs = kp_info['kp'].shape[0]
    for angle in ('pitch', 'yaw', 'roll'):
        kp_info[angle] = headpose_pred_to_degree(kp_info[angle])[:, None]  # Bx1
    for name in ('kp', 'exp'):
        kp_info[name] = kp_info[name].reshape(bs, -1, 3)  # BxNx3

    return kp_info

In [15]:
# Driving clip: every collected RGB frame is resized to the 256x256 model input
# and the whole clip is converted to a normalised tensor on `device`.
driving_rgb_lst = all_frames
driving_rgb_crop_256x256_lst = [
    cv2.resize(driving_frame, (256, 256))
    for driving_frame in driving_rgb_lst
]
I_d_lst = prepare_videos(driving_rgb_crop_256x256_lst, device)

In [16]:
# Build the motion template: one entry per driving frame holding the scale,
# rotation, expression and translation as float32 numpy arrays.
n_frames = I_d_lst.shape[0]
template_dct = dict(
    n_frames=n_frames,
    output_fps=25,
    motion=[],
    c_d_eyes_lst=[],
    c_d_lip_lst=[],
    x_i_info_lst=[],
)

for i in range(n_frames):
    # collect s, R, δ and t for inference
    I_i = I_d_lst[i]
    x_i_info = get_kp_info(I_i)
    R_i = get_rotation_matrix(x_i_info['pitch'], x_i_info['yaw'], x_i_info['roll'])

    motion_tensors = {
        'scale': x_i_info['scale'],
        'R': R_i,
        'exp': x_i_info['exp'],
        't': x_i_info['t'],
    }
    item_dct = {
        key: tensor.cpu().numpy().astype(np.float32)
        for key, tensor in motion_tensors.items()
    }

    template_dct['motion'].append(item_dct)
    # The raw keypoint info (still torch tensors) is kept alongside the numpy copy.
    template_dct['x_i_info_lst'].append(x_i_info)
    print(f'frame {i} done')

frame 0 done
frame 1 done
frame 2 done
frame 3 done
frame 4 done
frame 5 done
frame 6 done
frame 7 done
frame 8 done
frame 9 done
frame 10 done
frame 11 done
frame 12 done
frame 13 done
frame 14 done
frame 15 done
frame 16 done
frame 17 done
frame 18 done
frame 19 done
frame 20 done
frame 21 done
frame 22 done
frame 23 done
frame 24 done
frame 25 done
frame 26 done
frame 27 done
frame 28 done
frame 29 done
frame 30 done
frame 31 done
frame 32 done
frame 33 done
frame 34 done
frame 35 done
frame 36 done
frame 37 done
frame 38 done
frame 39 done
frame 40 done
frame 41 done
frame 42 done
frame 43 done
frame 44 done
frame 45 done
frame 46 done
frame 47 done
frame 48 done
frame 49 done
frame 50 done
frame 51 done
frame 52 done
frame 53 done
frame 54 done
frame 55 done
frame 56 done
frame 57 done
frame 58 done
frame 59 done
frame 60 done
frame 61 done
frame 62 done
frame 63 done
frame 64 done
frame 65 done
frame 66 done
frame 67 done
frame 68 done
frame 69 done
frame 70 done
frame 71 done
fr

In [17]:
def prepare_source(img: np.ndarray) -> torch.Tensor:
    """ construct the source input as standard
    img: HxWx3 (single image) or BxHxWx3 (batch), uint8; values are divided by 255
    return: Bx3xHxW float32 tensor clipped to 0~1, moved to the module-level `device`
    raises: ValueError if img is neither 3-D nor 4-D
    """
    # Validate the rank first so that every unsupported input reaches the
    # descriptive error below (no unrelated shape unpacking beforehand).
    if img.ndim == 3:
        batch = img[np.newaxis]  # HxWx3 -> 1xHxWx3
    elif img.ndim == 4:
        batch = img  # BxHxWx3
    else:
        raise ValueError(f'img ndim should be 3 or 4: {img.ndim}')

    # astype() allocates a new array, so the caller's image is never modified.
    x = np.clip(batch.astype(np.float32) / 255., 0, 1)  # normalized and clipped to 0~1
    x = torch.from_numpy(x).permute(0, 3, 1, 2)  # BxHxWx3 -> Bx3xHxW
    return x.to(device)

def warp_decode(feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> dict:
    """ get the image after the warping of the implicit keypoints
    feature_3d: Bx32x16x64x64, feature volume
    kp_source: BxNx3
    kp_driving: BxNx3
    return: the dict returned by warping_module, with its 'out' entry replaced
        by the spade_generator output (float32)
    """
    # The line 18 in Algorithm 1: D(W(f_s; x_s, x'_d,i))
    with torch.no_grad(), torch.autocast(device_type=device, dtype=torch.float16,
                                 enabled=inference_cfg.flag_use_half_precision):
        # get decoder input
        ret_dct = warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
        # decode
        ret_dct['out'] = spade_generator(feature=ret_dct['out'])

    # Cast the decoded image back to float32 under half precision, matching
    # get_kp_info and extract_feature_3d, so callers never see float16 output.
    if inference_cfg.flag_use_half_precision:
        ret_dct['out'] = ret_dct['out'].float()

    return ret_dct

def extract_feature_3d(x: torch.Tensor) -> torch.Tensor:
    """ run the appearance feature extractor F on the image
    x: Bx3xHxW, normalized to 0~1
    return: the extractor output cast to float32
    """
    with torch.no_grad():
        with torch.autocast(device_type=device, dtype=torch.float16,
                            enabled=inference_cfg.flag_use_half_precision):
            features = appearance_feature_extractor(x)

    return features.float()

def transform_keypoint(kp_info: dict):
    """
    apply pose, expression deformation, scale and in-plane shift to the implicit keypoints
    kp_info: dict with keys 'kp', 'pitch', 'yaw', 'roll', 't', 'exp', 'scale'
    kp: BxNx3, or Bx(Nx3) flattened
    return: transformed keypoints, BxNx3
    """
    kp, pitch, yaw, roll, t, exp, scale = (
        kp_info[name] for name in ('kp', 'pitch', 'yaw', 'roll', 't', 'exp', 'scale')
    )

    bs = kp.shape[0]
    # Bx(num_kpx3) when flattened, otherwise Bxnum_kpx3
    num_kp = kp.shape[1] // 3 if kp.ndim == 2 else kp.shape[1]

    rot_mat = get_rotation_matrix(
        headpose_pred_to_degree(pitch),
        headpose_pred_to_degree(yaw),
        headpose_pred_to_degree(roll),
    )  # (bs, 3, 3)

    # Eqn.2: s * (R * x_c,s + exp) + t
    rotated = kp.view(bs, num_kp, 3) @ rot_mat
    kp_transformed = rotated + exp.view(bs, num_kp, 3)
    kp_transformed *= scale[..., None]  # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3)
    kp_transformed[:, :, 0:2] += t[:, None, 0:2]  # only tx, ty are applied; z is left out

    return kp_transformed

def parse_output(out: torch.Tensor) -> np.ndarray:
    """ convert a generated image tensor to a channel-last uint8 array
    out: Bx3xHxW tensor, nominally in 0~1
    return: BxHxWx3, uint8 (values clipped to 0~1 before scaling to 0~255)
    """
    channel_last = out.data.cpu().numpy().transpose(0, 2, 3, 1)  # Bx3xHxW -> BxHxWx3
    unit_range = np.clip(channel_last, 0, 1)  # clip to 0~1
    scaled = np.clip(unit_range * 255, 0, 255)  # 0~1 -> 0~255
    return scaled.astype(np.uint8)

In [18]:
# Load the source image and wrap it in a one-element list.
input_path = '/mnt/c/Users/mjh/Downloads/live_in/t4.jpg'
img_rgb = load_image_rgb(input_path)
source_rgb_lst = [img_rgb]

# Landmarks are computed on the full-resolution image; the result is not used
# by the remaining statements of this cell.
source_lmk = cropper.calc_lmk_from_cropped_image(source_rgb_lst[0])
img_crop_256x256 = cv2.resize(source_rgb_lst[0], (256, 256))  # force to resize to 256x256

# Source-side quantities reused for every driving frame:
#   I_s   - normalised source tensor
#   x_c_s - the 'kp' entry of the source keypoint info
#   x_s   - source keypoints after transform_keypoint
#   f_s   - appearance feature volume of the source
I_s = prepare_source(img_crop_256x256)
x_s_info = get_kp_info(I_s)
x_c_s = x_s_info['kp']
x_s = transform_keypoint(x_s_info)
f_s = extract_feature_3d(I_s)


In [20]:
# Animate the source image with every motion entry of the template and show the
# result in an OpenCV window. Press 'q' in the window to stop early.

# Initialize variables
frame_index = 0
total_frames = len(template_dct['motion'])

# Create a window for display
cv2.namedWindow('Processed Frame', cv2.WINDOW_NORMAL)
cv2.resizeWindow('Processed Frame', 512, 512)  # Adjust size as needed

try:
    while frame_index < total_frames:
        # Get motion data for the current frame
        motion = template_dct['motion'][frame_index]
        R = motion['R']
        exp = motion['exp']
        t = motion['t']
        scale = motion['scale']

        # Convert to tensors (torch.tensor copies, so the template is untouched)
        scale_tensor = torch.tensor(scale, device=device)
        R_tensor = torch.tensor(R, device=device)
        exp_tensor = torch.tensor(exp, device=device)
        t_tensor = torch.tensor(t, device=device)
        # Zero the z translation: transform_keypoint only applies tx/ty to the
        # source keypoints, so the driving keypoints must not be shifted in z.
        t_tensor[..., 2] = 0

        # Process the frame
        x_d_i_new = scale_tensor * (x_c_s @ R_tensor + exp_tensor) + t_tensor
        out = warp_decode(f_s, x_s, x_d_i_new)

        # Convert to HxWx3 uint8 through parse_output, which clips to 0~1 first;
        # a bare `* 255` cast would wrap around for out-of-range values.
        img_np = parse_output(out['out'])[0]

        # Convert from RGB to BGR for cv2
        img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)

        # Display the frame
        cv2.imshow('Processed Frame', img_bgr)

        # Print progress
        print(f"Processed frame {frame_index+1}/{total_frames}")

        # Wait for a short time and check for 'q' key to quit
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

        frame_index += 1

        # Optional: add a small delay to make the display more visible
        time.sleep(0.03)  # Adjust as needed
finally:
    # Clean up the window even if a frame fails to render
    cv2.destroyAllWindows()

print("Processing complete.")