### Full pipeline inference
1. Load the target image to be driven.
2. Load the audio.
3. Run DiT inference on the audio to obtain motion latents.
4. Drive the target image with those latents.

### 1. Load target image

1.1 Prepare the full live pipeline

In [None]:
import time

import os
import contextlib
import os.path as osp
import numpy as np
import cv2
import torch
import yaml
import tyro
import subprocess
from rich.progress import track
import torchvision
import cv2
import threading
import queue
import torchvision.transforms as transforms
from concurrent.futures import ThreadPoolExecutor, as_completed
import glob
import os
import numpy as np
import time
import torch
import imageio

from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig
from src.config.crop_config import CropConfig

def partial_fields(target_class, kwargs):
    """Instantiate ``target_class`` with only the ``kwargs`` entries it recognises.

    An entry is kept when ``target_class`` itself has an attribute of that
    name (checked with ``hasattr`` on the class); every other entry is dropped.
    """
    accepted = {}
    for name, value in kwargs.items():
        if hasattr(target_class, name):
            accepted[name] = value
    return target_class(**accepted)

# Build the argument config from its defaults (no arguments are passed here),
# then keep only the fields each sub-config declares.
args = ArgumentConfig()
inference_cfg = partial_fields(InferenceConfig, args.__dict__)
crop_cfg = partial_fields(CropConfig, args.__dict__)
# print("inference_cfg: ", inference_cfg)
# print("crop_cfg: ", crop_cfg)
# NOTE(review): hard-coded; the audio-pipeline cell (section 2.1) later rebinds
# `device` to a torch.device with a CPU fallback.
device = 'cuda'
# NOTE(review): nothing is compiled in this cell; the message only signals that it finished.
print("Compile complete")

'''
Common modules
'''

from src.utils.helper import load_model, concat_feat
from src.utils.camera import headpose_pred_to_degree, get_rotation_matrix
from src.utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
from src.config.inference_config import InferenceConfig
from src.utils.cropper import Cropper
from src.utils.camera import get_rotation_matrix
from src.utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream
from src.utils.crop import _transform_img, prepare_paste_back, paste_back
from src.utils.io import load_image_rgb, load_video, resize_to_limit, dump, load
from src.utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix, is_image
from src.utils.filter import smooth


'''
Util functions
'''

def calculate_distance_ratio(lmk: np.ndarray, idx1: int, idx2: int, idx3: int, idx4: int, eps: float = 1e-6) -> np.ndarray:
    """Return ||lmk[:, idx1] - lmk[:, idx2]|| / (||lmk[:, idx3] - lmk[:, idx4]|| + eps).

    Norms are Euclidean over axis 1 of the selected points, and the result keeps
    that axis (keepdims): a (B, N, D) ``lmk`` gives a (B, 1) ratio. ``eps``
    guards against a zero denominator.
    """
    numerator = np.linalg.norm(lmk[:, idx1] - lmk[:, idx2], axis=1, keepdims=True)
    denominator = np.linalg.norm(lmk[:, idx3] - lmk[:, idx4], axis=1, keepdims=True)
    return numerator / (denominator + eps)


def calc_eye_close_ratio(lmk: np.ndarray, target_eye_ratio: np.ndarray = None) -> np.ndarray:
    """Compute the left and right eye ratios from landmarks.

    Each ratio is ``calculate_distance_ratio`` on fixed landmark indices:
    left = dist(6, 18) / dist(0, 12), right = dist(30, 42) / dist(24, 36).
    When ``target_eye_ratio`` is given, it is appended after the two ratios
    along axis 1.

    NOTE: this definition shadows the ``calc_eye_close_ratio`` imported from
    ``src.utils.retargeting_utils`` earlier in the notebook.
    """
    columns = [
        calculate_distance_ratio(lmk, 6, 18, 0, 12),
        calculate_distance_ratio(lmk, 30, 42, 24, 36),
    ]
    if target_eye_ratio is not None:
        columns.append(target_eye_ratio)
    return np.concatenate(columns, axis=1)


def calc_lip_close_ratio(lmk: np.ndarray) -> np.ndarray:
    """Return dist(90, 102) / dist(48, 66) via ``calculate_distance_ratio``.

    NOTE: shadows the ``calc_lip_close_ratio`` imported from
    ``src.utils.retargeting_utils`` earlier in the notebook.
    """
    numerator_pair = (90, 102)
    denominator_pair = (48, 66)
    return calculate_distance_ratio(lmk, *numerator_pair, *denominator_pair)

def calc_ratio(lmk_lst):
    """Compute eye and lip ratios for every landmark set in ``lmk_lst``.

    Each element gets a leading batch axis (``lmk[None]``) before being passed
    to ``calc_eye_close_ratio`` and ``calc_lip_close_ratio``.

    Returns:
        ``(eye_ratio_list, lip_ratio_list)``: two lists with one entry per input,
        in input order.
    """
    per_item = [
        (calc_eye_close_ratio(lmk[None]), calc_lip_close_ratio(lmk[None]))
        for lmk in lmk_lst
    ]
    eye_ratios = [eye for eye, _ in per_item]
    lip_ratios = [lip for _, lip in per_item]
    return eye_ratios, lip_ratios

def prepare_videos_(imgs, device):
    """Convert a stack of RGB frames into a normalized Nx3xHxW tensor on ``device``.

    Args:
        imgs: list of HxWx3 frames, or an NxHxWx3 ndarray (uint8 expected).
        device: target torch device.

    Returns:
        Floating-point tensor of shape Nx3xHxW, divided by 255 and clamped to [0, 1].

    Raises:
        ValueError: if ``imgs`` is neither a list nor an ndarray.
    """
    if isinstance(imgs, list):
        frames = np.array(imgs)
    elif isinstance(imgs, np.ndarray):
        frames = imgs
    else:
        raise ValueError(f'imgs type error: {type(imgs)}')

    # Transfer first, then scale on the device.
    tensor = torch.from_numpy(frames).permute(0, 3, 1, 2).to(device)  # NxHxWx3 -> Nx3xHxW
    return torch.clamp(tensor / 255., 0, 1)

def get_kp_info(x: torch.Tensor, **kwargs) -> dict:
    """Run the motion extractor and return its implicit-keypoint outputs.

    Args:
        x: Bx3xHxW image batch, normalized to 0~1.
        **kwargs: ``flag_refine_info`` (default True). When set, pitch/yaw/roll
            go through ``headpose_pred_to_degree`` and get a trailing axis
            (Bx1), and 'kp' / 'exp' are reshaped to BxNx3.

    Returns:
        The dict produced by ``motion_extractor``. The keys touched here are
        'pitch', 'yaw', 'roll', 'kp' and 'exp'; the original docstring also
        lists 't' and 'scale'. Tensor entries are cast to float32 when half
        precision is enabled.

    Relies on the module-level ``motion_extractor`` and ``inference_cfg``.
    """
    use_half = inference_cfg.flag_use_half_precision
    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_half):
        kp_info = motion_extractor(x)
        if use_half:
            # Cast every tensor entry back to float32, updating the same dict.
            for key in list(kp_info):
                if isinstance(kp_info[key], torch.Tensor):
                    kp_info[key] = kp_info[key].float()

    if not kwargs.get('flag_refine_info', True):
        return kp_info

    batch = kp_info['kp'].shape[0]
    for angle in ('pitch', 'yaw', 'roll'):
        kp_info[angle] = headpose_pred_to_degree(kp_info[angle])[:, None]  # Bx1
    for key in ('kp', 'exp'):
        kp_info[key] = kp_info[key].reshape(batch, -1, 3)  # BxNx3
    return kp_info

def read_video_frames(video_path, size=(256, 256)):
    """Decode every frame of a video as RGB, resized to ``size``.

    Args:
        video_path: path passed to ``cv2.VideoCapture``.
        size: ``(width, height)`` passed to ``cv2.resize``; defaults to 256x256.

    Returns:
        ``(video_path, frames)``, where ``frames`` is a list of RGB arrays.
        A video that cannot be opened or decoded yields an empty list.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        # Read until the decoder reports end of stream instead of looping over
        # CAP_PROP_FRAME_COUNT. That value is container metadata and can be
        # missing (0 / -1) or wrong, which silently dropped or truncated frames.
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(cv2.resize(frame, size))
    finally:
        # Release the capture even if decoding or conversion raises.
        cap.release()
    return video_path, frames

def read_multiple_videos(video_paths, num_threads=4):
    """Decode several videos concurrently with ``read_video_frames``.

    Returns a list of ``(video_path, frames)`` tuples in the same order as
    ``video_paths`` (``Executor.map`` preserves input order).
    """
    pool = ThreadPoolExecutor(max_workers=num_threads)
    with pool:
        decoded = [item for item in pool.map(read_video_frames, video_paths)]
    return decoded

def euler_to_quaternion(pitch, yaw, roll):
    """Convert Euler angles in radians to a ``(w, x, y, z)`` quaternion.

    Convention: roll about x, pitch about y, yaw about z. This is the
    convention ``quaternion_to_euler`` inverts. The four components are
    stacked on a new trailing axis of size 4.
    """
    half_yaw, half_pitch, half_roll = yaw * 0.5, pitch * 0.5, roll * 0.5
    cy, sy = torch.cos(half_yaw), torch.sin(half_yaw)
    cp, sp = torch.cos(half_pitch), torch.sin(half_pitch)
    cr, sr = torch.cos(half_roll), torch.sin(half_roll)

    components = (
        cr * cp * cy + sr * sp * sy,  # w
        sr * cp * cy - cr * sp * sy,  # x
        cr * sp * cy + sr * cp * sy,  # y
        cr * cp * sy - sr * sp * cy,  # z
    )
    return torch.stack(components, dim=-1)

def quaternion_to_euler(q):
    """
    Convert quaternion to Euler angles (pitch, yaw, roll) in radians.

    q: torch.Tensor of shape (..., 4) representing quaternions (w, x, y, z)
    Returns: tuple of (pitch, yaw, roll) as torch.Tensor, each shaped like q[..., 0]

    Convention: roll about x, pitch about y, yaw about z (the inverse of
    ``euler_to_quaternion``). At gimbal lock (|sin(pitch)| >= 1), pitch
    saturates to +-pi/2.
    """
    # Extract the values from q
    w, x, y, z = q[..., 0], q[..., 1], q[..., 2], q[..., 3]

    # Roll (x-axis rotation)
    sinr_cosp = 2 * (w * x + y * z)
    cosr_cosp = 1 - 2 * (x * x + y * y)
    roll = torch.atan2(sinr_cosp, cosr_cosp)

    # Pitch (y-axis rotation).
    # Rounding, or a non-unit q, can push sinp outside [-1, 1].
    # Clamping before asin gives +-pi/2 there, the same values as the former
    # torch.where(|sinp| >= 1, sign(sinp) * pi / 2, asin(sinp)).
    # Unlike torch.where, it never evaluates asin out of its domain. With
    # torch.where, the unselected asin branch still produced NaN, and that NaN
    # leaks into gradients during backward.
    sinp = 2 * (w * y - z * x)
    pitch = torch.asin(torch.clamp(sinp, -1.0, 1.0))

    # Yaw (z-axis rotation)
    siny_cosp = 2 * (w * z + x * y)
    cosy_cosp = 1 - 2 * (y * y + z * z)
    yaw = torch.atan2(siny_cosp, cosy_cosp)

    return pitch, yaw, roll

def quaternion_to_euler_degrees(q):
    """
    Convert quaternion to Euler angles (pitch, yaw, roll) in degrees.

    q: torch.Tensor of shape (..., 4) representing quaternions (w, x, y, z)
    Returns: tuple of (pitch, yaw, roll) as torch.Tensor in degrees
    """
    return tuple(torch.rad2deg(angle) for angle in quaternion_to_euler(q))


'''
Loading source related modules
'''
def prepare_source(img: np.ndarray) -> torch.Tensor:
    """Construct the standard model input from an RGB image.

    img: HxWx3 (single image) or BxHxWx3 (batch), uint8; 256x256 in this pipeline.
    Returns a float32 Bx3xHxW tensor in [0, 1] on the module-level ``device``.

    Raises:
        ValueError: if ``img`` is not 3- or 4-dimensional.
    """
    # The dimensionality check comes first. The old unused `h, w = img.shape[:2]`
    # made 1-D input fail with a confusing unpack error instead of this message.
    if img.ndim == 3:
        x = img[np.newaxis]  # HxWx3 -> 1xHxWx3
    elif img.ndim == 4:
        x = img  # already BxHxWx3
    else:
        raise ValueError(f'img ndim should be 3 or 4: {img.ndim}')

    # astype() always returns a new array, so the caller's image is never modified.
    # This is why the old explicit img.copy() was redundant.
    x = np.clip(x.astype(np.float32) / 255., 0, 1)  # normalize and clip to 0~1
    x = torch.from_numpy(x).permute(0, 3, 1, 2)  # BxHxWx3 -> Bx3xHxW
    return x.to(device)

def warp_decode(feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> dict:
    """ get the image after the warping of the implicit keypoints
    feature_3d: Bx32x16x64x64, feature volume
    kp_source: BxNx3
    kp_driving: BxNx3
    return: the dict returned by ``warping_module``, with its 'out' entry
        replaced by the ``spade_generator`` output (the decoded image).
        The previous ``-> torch.Tensor`` annotation was wrong: a dict is returned.

    Runs without gradients, under fp16 autocast when
    ``inference_cfg.flag_use_half_precision`` is set. Relies on the module-level
    ``warping_module``, ``spade_generator`` and ``inference_cfg``.
    """
    # The line 18 in Algorithm 1: D(W(f_s; x_s, x′_d,i))
    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16,
                                 enabled=inference_cfg.flag_use_half_precision):
        # get decoder input
        ret_dct = warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
        # decode
        ret_dct['out'] = spade_generator(feature=ret_dct['out'])

    return ret_dct

def extract_feature_3d(x: torch.Tensor) -> torch.Tensor:
    """Return the appearance feature volume of ``x`` computed by F.

    x: Bx3xHxW, normalized to 0~1
    Runs ``appearance_feature_extractor`` without gradients (under fp16 autocast
    when ``inference_cfg.flag_use_half_precision`` is set) and returns float32.
    """
    autocast_ctx = torch.autocast(device_type='cuda', dtype=torch.float16,
                                  enabled=inference_cfg.flag_use_half_precision)
    with torch.no_grad(), autocast_ctx:
        features = appearance_feature_extractor(x)
    return features.float()

def transform_keypoint(kp_info: dict):
    """
    Apply pose, expression deformation, scale and x/y shift to the keypoints.

    Implements Eqn. 2, s * (x_c @ R + exp) + t. Only the x/y components of
    ``t`` are added; the z translation is dropped.

    kp_info: dict with 'kp', 'pitch', 'yaw', 'roll', 't', 'exp', 'scale';
        'kp' may be Bx(N*3) or BxNx3.
    return: BxNx3 transformed keypoints.

    NOTE(review): pitch/yaw/roll are passed through ``headpose_pred_to_degree``
    here, and ``get_kp_info`` (default ``flag_refine_info=True``) already does
    the same. Results are therefore only correct if that helper is a no-op on
    already-converted Bx1 angles; confirm against src.utils.camera.
    """
    kp = kp_info['kp']  # (bs, k, 3), or flattened (bs, k*3)
    raw_angles = (kp_info['pitch'], kp_info['yaw'], kp_info['roll'])
    t, exp, scale = kp_info['t'], kp_info['exp'], kp_info['scale']

    pitch, yaw, roll = (headpose_pred_to_degree(angle) for angle in raw_angles)

    bs = kp.shape[0]
    # A 2-D kp stores the keypoints flattened as N*3 columns.
    num_kp = kp.shape[1] // 3 if kp.ndim == 2 else kp.shape[1]

    rot_mat = get_rotation_matrix(pitch, yaw, roll)  # (bs, 3, 3)

    # Eqn.2: s * (R * x_c,s + exp) + t
    out = kp.view(bs, num_kp, 3) @ rot_mat + exp.view(bs, num_kp, 3)
    out *= scale[..., None]  # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3)
    out[:, :, 0:2] += t[:, None, 0:2]  # remove z, only apply tx ty
    return out

def parse_output(out: torch.Tensor) -> np.ndarray:
    """Convert a Bx3xHxW tensor with values in ~[0, 1] to a BxHxWx3 uint8 array.

    Values are clipped to [0, 1] and scaled by 255. The uint8 cast truncates
    rather than rounds. A second clip to [0, 255] would be a no-op after the
    first, so it is omitted.
    """
    arr = out.data.cpu().numpy().transpose(0, 2, 3, 1)  # Bx3xHxW -> BxHxWx3
    return (np.clip(arr, 0, 1) * 255).astype(np.uint8)
'''
Main module for inference
'''
# Load the model-architecture config, then build each network from its checkpoint.
# The with-block closes the file handle that the bare open() used to leak.
# yaml.safe_load(f) is equivalent to yaml.load(f, Loader=yaml.SafeLoader).
with open(inference_cfg.models_config, 'r') as _models_config_file:
    model_config = yaml.safe_load(_models_config_file)
# init F (appearance feature extractor)
appearance_feature_extractor = load_model(inference_cfg.checkpoint_F, model_config, device, 'appearance_feature_extractor')
# init M (motion extractor)
motion_extractor = load_model(inference_cfg.checkpoint_M, model_config, device, 'motion_extractor')
# init W (warping module)
warping_module = load_model(inference_cfg.checkpoint_W, model_config, device, 'warping_module')
# init G (SPADE generator, the decoder used by warp_decode)
spade_generator = load_model(inference_cfg.checkpoint_G, model_config, device, 'spade_generator')
# init S and R: optional, left as None when the checkpoint is unset or missing on disk
if inference_cfg.checkpoint_S is not None and os.path.exists(inference_cfg.checkpoint_S):
    stitching_retargeting_module = load_model(inference_cfg.checkpoint_S, model_config, device, 'stitching_retargeting_module')
else:
    stitching_retargeting_module = None

cropper = Cropper(crop_cfg=crop_cfg, device=device)


1.2 Load single image

In [None]:
# Machine-specific path of the source (target) image to animate; edit as needed.
input_path = '/mnt/c/Users/mjh/Downloads/live_in/t5.jpg'
img_rgb = load_image_rgb(input_path)
source_rgb_lst = [img_rgb]

# NOTE(review): source_lmk is not used by any visible later cell.
source_lmk = cropper.calc_lmk_from_cropped_image(source_rgb_lst[0])
# The whole image is resized here; no cropping happens despite the variable name.
img_crop_256x256 = cv2.resize(source_rgb_lst[0], (256, 256))  # force to resize to 256x256

# Source-side quantities used by later cells:
I_s = prepare_source(img_crop_256x256)  # 1x3x256x256 float tensor in [0, 1]
x_s_info = get_kp_info(I_s)  # motion-extractor outputs (pitch/yaw/roll/kp/exp/...)
x_c_s = x_s_info['kp']  # 'kp' entry, reshaped to BxNx3 by get_kp_info
x_s = transform_keypoint(x_s_info)  # keypoints after pose/expression/scale/shift
f_s = extract_feature_3d(I_s)  # appearance feature volume


### 2. Load audio

In [None]:
# True: drive with the custom WAV loaded in section 2.2.
# False: use the dataset clip selected in section 2.3.
use_custom_audio_data = True

2.1 Prepare audio pipeline

In [None]:
import json
import torch
import torchaudio
from transformers import Wav2Vec2Model, Wav2Vec2Processor
import os
import numpy as np
from typing import List, Tuple
import torch.multiprocessing as mp
import torch.distributed as dist
import torch.nn.functional as F
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

# Constants
MODEL_NAME = "facebook/wav2vec2-base-960h"
TARGET_SAMPLE_RATE = 16000  # Hz; audio is resampled to this rate before wav2vec2
FRAME_RATE = 25  # target frames/s; wav2vec2 features are interpolated by FRAME_RATE / 50
SECTION_LENGTH = 3  # seconds of audio per section
OVERLAP = 10  # overlap between consecutive sections, in FRAME_RATE frames (10 / 25 s)

# Locations used by the batch feature-extraction helpers (read_json_and_form_paths).
DB_ROOT = 'vox2-audio-tx'
LOG = 'log'  # NOTE(review): not referenced anywhere in the visible notebook
AUDIO = 'audio/audio'
OUTPUT_DIR = 'audio_encoder_output'
BATCH_SIZE = 16  # wav2vec2 mini-batch size in process_wav_file


# Move model and processor to global scope
# NOTE: rebinds the earlier device = 'cuda' string, with a CPU fallback.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
wav2vec_model = Wav2Vec2Model.from_pretrained(MODEL_NAME).to(device)
wav2vec_processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)

def read_multiple_audios(paths, num_threads=12):
    """Load and segment several audio files concurrently.

    Each path goes through ``load_and_process_audio``. Results come back in
    the same order as ``paths`` (``Executor.map`` preserves input order).
    """
    pool = ThreadPoolExecutor(max_workers=num_threads)
    with pool:
        loaded = [item for item in pool.map(load_and_process_audio, paths)]
    return loaded


def read_json_and_form_paths(data, id_key):
    """Build input WAV paths and output feature paths from a nested clip index.

    ``data`` is a nested mapping ``{id: {url: {clip: ...}}}``; only the keys
    are used, and clip keys are expected to contain '.txt'. For each clip:
        input:  DB_ROOT/AUDIO/<id>/<url>/<clip with '.txt' -> '.wav'>
        output: DB_ROOT/OUTPUT_DIR/<id>/<url>+<clip with '.txt' removed>
    The per-id output directory is created as a side effect.

    ``id_key`` is accepted only for backward compatibility and is ignored. The
    original loop variable reused the same name and silently shadowed it.

    Returns:
        ``(file_paths, filenames)``: parallel lists in the iteration order of ``data``.
    """
    file_paths = []
    filenames = []
    for id_name, urls in data.items():
        out_dir = os.path.join(DB_ROOT, OUTPUT_DIR, id_name)
        os.makedirs(out_dir, exist_ok=True)
        for url_key, clips in urls.items():
            for clip_id in clips.keys():
                file_paths.append(os.path.join(DB_ROOT, AUDIO, id_name, url_key, clip_id.replace('.txt', '.wav')))
                filenames.append(os.path.join(out_dir, url_key + '+' + clip_id.replace('.txt', '')))
    return file_paths, filenames

def load_and_process_audio(file_path):
    """Load an audio file and cut it into fixed-length, overlapping mono sections.

    Steps:
    - Resample the waveform to TARGET_SAMPLE_RATE.
    - Average the channels down to mono.
    - Split into sections of ``SECTION_LENGTH * 16027`` samples, with section
      starts ``section_samples - overlap_samples`` apart.
    - Zero-pad the last section (or the only one, for short clips) to full length.

    Returns:
        ``(file_path, sections)``, where ``sections`` is a list of 1-D tensors,
        each ``section_samples`` long. The short-clip path now returns the same
        shape. It used to return ``([waveform], original_sample_rate)``, which
        broke every caller that unpacks ``audio_path, segments``
        (inference_process_wav_file, process_wav_file).
    """
    waveform, sample_rate = torchaudio.load(file_path)

    if sample_rate != TARGET_SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, sample_rate, TARGET_SAMPLE_RATE)

    # Convert to mono if stereo
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    print(file_path," waveform.shape ",waveform.shape)

    # Section length and overlap in samples.
    # NOTE(review): the 16027 multiplier (rather than TARGET_SAMPLE_RATE) looks deliberate.
    # With wav2vec2-base's convolutional front-end, 3 * 16027 samples appear to
    # yield 150 feature frames (75 at FRAME_RATE), where 3 * 16000 would yield 149.
    # Confirm before changing it.
    section_samples = SECTION_LENGTH * 16027
    overlap_samples = int(OVERLAP / FRAME_RATE * TARGET_SAMPLE_RATE)
    print('section_samples',section_samples,'overlap_samples',overlap_samples)

    num_samples = waveform.shape[1]
    # Pad if shorter than one section
    if num_samples < section_samples:
        waveform = torch.nn.functional.pad(waveform, (0, section_samples - num_samples))
        return file_path, [waveform.squeeze(0)]

    # Split into sections with overlap
    sections = []
    start = 0

    print('starting to segment', file_path)
    while start < num_samples:
        end = start + section_samples
        if end >= num_samples:
            # Final section: take whatever remains and zero-pad it to full length.
            tail = waveform[:, start:num_samples]
            tail = torch.nn.functional.pad(tail, (0, section_samples - tail.shape[1]))
            sections.append(tail.squeeze(0))
            print(tail.shape)
            break
        sections.append(waveform[:, start:end].squeeze(0))
        start = int(end - overlap_samples)

    return file_path, sections

def inference_process_wav_file(path):
    """Encode one audio file into wav2vec2 features resampled to FRAME_RATE.

    Steps:
    - Split the file into sections with ``load_and_process_audio``.
    - Encode all sections as one batch with the global ``wav2vec_processor``
      and ``wav2vec_model``.
    - Linearly interpolate the last hidden state along time by a factor of
      FRAME_RATE / 50 (the model output is assumed to be ~50 Hz).

    Returns:
        Tensor of shape (num_sections, T', hidden) on ``device``.
    """
    audio_path, segments = load_and_process_audio(path)
    print(audio_path,segments)
    batch = np.array(segments)

    model_inputs = wav2vec_processor(batch, sampling_rate=TARGET_SAMPLE_RATE,
                                     return_tensors="pt", padding=True).input_values
    model_inputs = model_inputs.to(device)

    with torch.no_grad():
        hidden = wav2vec_model(model_inputs).last_hidden_state  # (B, T, C)
        target_len = int(hidden.shape[1] * (FRAME_RATE / 50))
        # F.interpolate resamples the last axis, so move time there and back.
        resampled = F.interpolate(hidden.transpose(1, 2), size=target_len,
                                  mode='linear', align_corners=True).transpose(1, 2)
    return resampled


def process_wav_file(paths, output_paths, uid):
    """Batch-encode many audio files with wav2vec2 and save one array per file.

    Args:
        paths: input audio files, loaded and segmented via ``read_multiple_audios``.
        output_paths: ``np.save`` targets, parallel to ``paths``.
        uid: label shown in the progress bar.

    Processing:
    - The sections of all files are flattened into one array.
    - That array is encoded in chunks of ``read_2_gpu_batch_size`` rows, each
      split into ``BATCH_SIZE`` mini-batches.
    - Features accumulate in ``process_queue``. Whenever the queue holds at least
      as many rows as the next file has sections, that file's rows are saved and
      popped, so outputs are written in input order.

    NOTE(review): loads a fresh model/processor on every call and hard-codes
    CUDA, independent of the module-level ``wav2vec_model`` / ``device``.
    """
    device = torch.device(f"cuda")

    model = Wav2Vec2Model.from_pretrained(MODEL_NAME).to(device)
    processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)

    read_2_gpu_batch_size = 2048  # sections encoded per outer iteration
    gpu_batch_size = BATCH_SIZE  # sections per model forward pass
    process_queue = torch.Tensor().to(device)  # encoded rows not yet saved

    audio_segments = read_multiple_audios(paths, num_threads=4)
    all_segments = []
    total_segments = 0

    audio_lengths = []  # number of sections of each not-yet-saved file, in order
    output_fns = []  # matching output paths

    # audio_path is unused; files are matched to outputs purely by order.
    for (audio_path, segments), output_fn in zip(audio_segments,output_paths):
        all_segments.extend(segments)
        segment_count = len(segments)
        total_segments += segment_count
        audio_lengths.append(segment_count)

        output_fns.append(output_fn)

    all_segments = np.array(all_segments)
    print(all_segments.size)

    read_data_2_gpu_pointer = 0
    pbar = tqdm(total=total_segments, desc=f"Processing {uid}")

    while read_data_2_gpu_pointer < total_segments:
        current_batch_size = min(read_2_gpu_batch_size, total_segments - read_data_2_gpu_pointer)

        batch_input = all_segments[read_data_2_gpu_pointer:read_data_2_gpu_pointer + current_batch_size]

        # Encode this chunk in mini-batches of gpu_batch_size sections.
        mini_batch_start = 0
        all_info = []
        while mini_batch_start < batch_input.shape[0]:
            mini_batch_end = min(mini_batch_start + gpu_batch_size, batch_input.shape[0])
            mini_batch = batch_input[mini_batch_start:mini_batch_end]

            inputs = processor(mini_batch, sampling_rate=TARGET_SAMPLE_RATE, return_tensors="pt", padding=True).input_values.to(device)

            with torch.no_grad():
                outputs = model(inputs)

            latent = outputs.last_hidden_state
            print('latent',latent.shape)
            seq_length = latent.shape[1]
            new_seq_length = int(seq_length * (FRAME_RATE / 50))  # Assuming Wav2Vec2 outputs at ~50Hz

            # Resample along time: (B, T, C) -> (B, C, T) -> interpolate -> (B, T', C)
            latent_features_interpolated = F.interpolate(latent.transpose(1,2),
                                                            size=new_seq_length,
                                                            mode='linear',
                                                            align_corners=True).transpose(1,2)
            print('latent_features_interpolated',latent_features_interpolated.shape)
            all_info.append(latent_features_interpolated)

            mini_batch_start = mini_batch_end
        all_info_tensor = torch.cat(all_info, dim=0)

        process_queue = torch.cat((process_queue, all_info_tensor), dim=0)

        # Save every file whose sections are now fully encoded, in input order.
        print(audio_lengths)
        while len(output_fns) > 0 and len(process_queue) >= audio_lengths[0]:
            current_output_fn = output_fns[0]
            current_segment_count = audio_lengths[0]

            audio_tensor = process_queue[:current_segment_count]
            np.save(current_output_fn, audio_tensor.cpu().numpy())
            print('save',current_output_fn)
            process_queue = process_queue[current_segment_count:]
            output_fns.pop(0)
            audio_lengths.pop(0)


        read_data_2_gpu_pointer += current_batch_size
        pbar.update(current_batch_size)

    pbar.close()


2.2 Load custom audio

In [None]:
# Machine-specific path of the WAV used when use_custom_audio_data is True.
custom_audio_path = '/mnt/c/Users/mjh/Downloads/live_in/i8.wav'
# wav2vec2 features of each audio section, interpolated to FRAME_RATE.
custom_audio_latent = inference_process_wav_file(custom_audio_path)

In [None]:
# Display the audio-feature shape: (num_sections, frames, hidden).
custom_audio_latent.shape

2.3 Dataset audio

In [None]:
# Evaluation clips. The three lists are index-aligned: entry i of "motion"
# (saved motion latents), "audio" (saved audio latents) and "wav" (source
# audio) belong to the same clip. The cell below selects one entry via `index`.
input_dict = {
    "motion":[
        '/mnt/e/data/live_latent/motion_latent/id00078/P0OU4bFhwCI+00227.npy',
        '/mnt/e/data/live_latent/motion_latent/id00019/anX4gftNLoc+00140.npy',
        '/mnt/e/data/live_latent/motion_latent/id00012/aE4Om0EEiuk+00117.npy',
        '/mnt/e/data/live_latent/motion_latent/id09263/ARVWnF_NcCI+00002.npy',
        # HDTF
        '/mnt/e/data/diffposetalk_data/TFHP_raw/test_split/live_latent/TH_00314/006.npy',
        '/mnt/e/data/diffposetalk_data/TFHP_raw/test_split/live_latent/TH_00226/001.npy',
        '/mnt/e/data/diffposetalk_data/TFHP_raw/train_split/live_latent/TH_00287/000.npy',
        '/mnt/e/data/diffposetalk_data/TFHP_raw/train_split/live_latent/TH_00192/000.npy',
        '/mnt/e/data/diffposetalk_data/TFHP_raw/train_split/live_latent/TH_00028/000.npy',
        '/mnt/e/data/diffposetalk_data/TFHP_raw/train_split/live_latent/TH_00224/000.npy',
    ],
    "audio":[
        '/mnt/e/data/live_latent/audio_latent/id00078/P0OU4bFhwCI+00227.npy',
        '/mnt/e/data/live_latent/audio_latent/id00019/anX4gftNLoc+00140.npy',
        '/mnt/e/data/live_latent/audio_latent/id00012/aE4Om0EEiuk+00117.npy',
        '/mnt/e/data/live_latent/audio_latent/id09263/ARVWnF_NcCI+00002.npy',
        # HDTF
        '/mnt/e/data/diffposetalk_data/TFHP_raw/test_split/audio_latent/TH_00314/006.npy',
        '/mnt/e/data/diffposetalk_data/TFHP_raw/test_split/audio_latent/TH_00226/001.npy',
        '/mnt/e/data/diffposetalk_data/TFHP_raw/train_split/audio_latent/TH_00287/000.npy',
        '/mnt/e/data/diffposetalk_data/TFHP_raw/train_split/audio_latent/TH_00192/000.npy',
        '/mnt/e/data/diffposetalk_data/TFHP_raw/train_split/audio_latent/TH_00028/000.npy',
        '/mnt/e/data/diffposetalk_data/TFHP_raw/train_split/audio_latent/TH_00224/000.npy',
    ],
    "wav":[
        '/mnt/c/Users/mjh/Downloads/audio_id00078_P0OU4bFhwCI_00227.wav',
        '/mnt/c/Users/mjh/Downloads/audio_id00019_anX4gftNLoc_00140.wav',
        '/mnt/c/Users/mjh/Downloads/audio_id00012_aE4Om0EEiuk_00117.wav',
        '/mnt/c/Users/mjh/Downloads/audio_id09263_ARVWnF_NcCI_00002.wav',
        # HDTF
        '/mnt/e/data/diffposetalk_data/TFHP_raw/audio/TH_00314/006.wav',
        '/mnt/e/data/diffposetalk_data/TFHP_raw/audio/TH_00226/001.wav',
        '/mnt/e/data/diffposetalk_data/TFHP_raw/audio/TH_00287/000.wav',
        '/mnt/e/data/diffposetalk_data/TFHP_raw/audio/TH_00192/000.wav',
        '/mnt/e/data/diffposetalk_data/TFHP_raw/audio/TH_00028/000.wav',
        '/mnt/e/data/diffposetalk_data/TFHP_raw/audio/TH_00224/000.wav',
    ]
}

In [None]:
import numpy as np
# Pick one evaluation clip from input_dict; index 6 is one of the HDTF entries.
index = 6
# NOTE: the motion latents of this clip are used later (ground truth and context)
# even when use_custom_audio_data is True.
motion_latent_path = input_dict['motion'][index]
# NOTE(review): not used below; the audio is re-encoded from the WAV instead.
dataset_audio_latent_path = input_dict['audio'][index]
dataset_audio_wav_path = input_dict['wav'][index]
# audio_latent, motion_latent_input = load_and_process_pair(audio_latent_path, motion_latent_npy)

# Only encode the dataset WAV when it will actually drive the model.
if not use_custom_audio_data:
    dataset_audio_latent = inference_process_wav_file(dataset_audio_wav_path)

### 3. Load DiT model

3.0 Select the DiT model (config and weights)

In [None]:
from audio_dit.inference import InferenceManager, get_model
from audio_dit.dataset import load_and_process_pair
config_path = 'audio_dit/output/config.json'
weight_path = 'audio_dit/output/model_sterotype_1_140.pth'
print("config_path exists:", os.path.exists(config_path))
# DiT hyper-parameters (e.g. latent_mask_1, latent_bound, x_dim) read by later cells.
# The with-block closes the file handle that the bare open() used to leak.
with open(config_path) as _config_file:
    audio_model_config = json.load(_config_file)
inference_manager = get_model(config_path, weight_path, device)

In [None]:
def count_parameters(model):
    """Return the total number of scalar parameters in ``model``, trainable or not."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

# Size report: the DiT (all vs. requires_grad parameters) and the two rendering networks.
total_params = count_parameters(inference_manager.model)
trainable_params = sum(p.numel() for p in inference_manager.model.parameters() if p.requires_grad)

print(f'Total parameters: {total_params:,}')
print(f'Trainable parameters: {trainable_params:,}')


total_params = count_parameters(warping_module)
print(f'Warping module parameters: {total_params:,}')
# SPADE generator (the decoder called in warp_decode)
total_params = count_parameters(spade_generator)
print(f'Spade generator parameters: {total_params:,}')


In [None]:
# Choose which audio features drive the model.
if use_custom_audio_data:
    used_audio_example = custom_audio_path
    audio_latent = custom_audio_latent
else:
    used_audio_example = dataset_audio_wav_path
    audio_latent = dataset_audio_latent

audio_latent_input = audio_latent
latent_mask_used=audio_model_config['latent_mask_1']
latent_bound = audio_model_config['latent_bound']
print(latent_mask_used, latent_bound)
# latent_mask_used = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
# Pair the audio features with the dataset clip's motion latents. The output
# layout comes from audio_dit.dataset.load_and_process_pair and is assumed to be
# (sections, frames, features); confirm there.
# NOTE(review): with custom audio, motion_latent_path still points at the dataset
# clip, so motion_gt below is that clip's motion, not ground truth for the custom audio.
motion_latent_processed, audio_latent_input, shape_in, mouth_ratio, _ = load_and_process_pair( \
                        audio_latent_input, motion_latent_path, 'exp',
                        latent_mask_1=latent_mask_used, latent_bound=latent_bound)
# The first 10 frames of each section are conditioning context; the rest is the target.
audio_seq = audio_latent_input[:, 10:, :]
audio_prev = audio_latent_input[:, :10, :]
motion_prev = motion_latent_processed[:, :10, :]
motion_gt = motion_latent_processed[:, 10:, :]

# Print shapes of input tensors
# NOTE(review): "Shape input shape" is printed twice.
print("Audio input shape:", audio_seq.shape)
print("Audio previous shape:", audio_prev.shape)
print("Shape input shape:", shape_in.shape)
print("Motion previous shape:", motion_prev.shape)
print("Motion ground truth shape:", motion_gt.shape)
print("Shape input shape:", shape_in.shape)
print("Mouth ratio:", mouth_ratio.shape)

motion_gt = motion_gt.to(device)
motion_prev = motion_prev.to(device)
# NOTE(review): 6 features are hard-coded here, while the DiT cell uses
# audio_model_config['x_dim']; confirm the intended motion width.
if use_custom_audio_data:
    motion_prev = torch.zeros(audio_latent.shape[0], 10, 6, device=device)


3.1 Prepare DiT Model

In [None]:
# Compare the dataset shape_in with the source image's keypoints; the next DiT cell
# replaces shape_in with the flattened x_c_s.
shape_in.shape, x_c_s.shape, mouth_ratio


In [None]:
# If True, the DiT loop subtracts each frame's mean over the last (feature)
# axis from every generated chunk.
subtract_avg_motion = True

In [None]:
# Mouth-opening conditioning scalar, fed to the DiT as a (1, 1) tensor.
# NOTE(review): 0.1337 is a hand-picked value; its provenance is not shown here.
mouth_open_ratio_val = 0.1337
mouth_open_ratio_input = torch.tensor([mouth_open_ratio_val], device=device).unsqueeze(0)
out_motion = torch.tensor([], device=device)  # generated chunks, concatenated along dim 0
B, T, audio_dim = audio_seq.shape  # only audio_dim is used below
motion_dim = audio_model_config['x_dim']
# Condition on the source image's keypoints instead of the dataset shape_in.
shape_in = x_c_s.reshape(1, -1).to(device)
# shape_in = shape_in[0:1]
# NOTE(review): this zero audio context is overwritten inside the loop for every
# batch_index, so it is never actually used.
this_audio_prev = torch.zeros(1, 10, audio_dim, device=device)
this_motion_prev = torch.zeros(1, 10, motion_dim , device=device)
init_motion_prev = this_motion_prev  # NOTE(review): unused below
if not use_custom_audio_data:
    this_audio_prev = audio_prev[0:1]
    this_motion_prev = motion_prev[0:1]
# this_audio_prev = None
# this_motion_prev = None
print("Audio input shape:", audio_seq.shape)
# print("Audio previous shape:", this_audio_prev.shape)
# print("Motion previous shape:", this_motion_prev.shape)
# Generate one section at a time, autoregressively: each section's motion context
# is the last 10 generated frames of the previous one.
for batch_index in range(0, audio_seq.shape[0]):
    print(f'batch_index: {batch_index}')
    # this_motion_prev = motion_prev[batch_index:batch_index+1]
    # this_audio_prev = audio_prev[batch_index:batch_index+1]
    this_audio_prev = audio_prev[batch_index:batch_index+1]
    generated_motion = inference_manager.inference(audio_seq[batch_index:batch_index+1],
                                                shape_in, this_motion_prev, this_audio_prev, #seq_mask=seq_mask,
                                                cfg_scale=0.,
                                                mouth_open_ratio = mouth_open_ratio_input,
                                                denoising_steps=1)
    # The context is taken before any mean subtraction below.
    this_motion_prev = generated_motion[:, -10:, :]
    # generated_motion = generated_motion.reshape(-1, 6)

    # NOTE(review): dim=-1 removes the per-frame mean across features; if a
    # temporal average was intended, that would be dim=1. Confirm.
    if subtract_avg_motion:
        generated_motion = generated_motion - torch.mean(generated_motion, dim=-1, keepdim=True)
    out_motion = torch.cat((out_motion, generated_motion), dim=0)


In [None]:
# All generated sections, stacked along dim 0 by the DiT loop.
generated_motion = out_motion

In [None]:
# Flatten (sections, frames, features) into (frames, features).
# NOTE(review): 20 is hard-coded; presumably it equals the motion feature
# width (cf. audio_model_config['x_dim']). Confirm.
generated_motion = generated_motion.reshape(-1, 20)
motion_gt = motion_gt.reshape(-1, 20)

3.2 Check shapes and compare with ground truth

In [None]:
# Mean squared error over all generated vs. ground-truth motion values.
# NOTE(review): only meaningful when both come from the same dataset clip
# (use_custom_audio_data=False); F.mse_loss expects matching shapes.
loss = F.mse_loss(generated_motion, motion_gt)
print(f"MSE Loss between generated motion and ground truth: {loss.item():.3e}")

In [None]:
# avg_diff_per_feat = 0
# for i in range(motion_gt.shape[0]):
#     all_diff = [f"{diff.item():.3e}" for diff in (generated_motion[0, i] - motion_gt[0, i])]
#     diff_sq = [ float(i) * float(i) for i in all_diff]
#     all_diff_sq = [f"{diff:.3e}" for diff in diff_sq]
#     # print(f"Frame {i}")
#     # print("Generated motion:", generated_motion[0, i])
#     # print(f"Ground truth motion: {motion_gt[0, i]}")
#     # print(f"L1 Difference: {all_diff}")
#     # print(f"L2 Difference: {all_diff_sq}")
#     local_gen_flatten = generated_motion[i, :, :].reshape(-1, generated_motion.shape[-1])
#     local_gt_flatten = motion_gt[i, :, :].reshape(-1, motion_gt.shape[-1])
#     print(f"Per segment MSE loss: {F.mse_loss(local_gen_flatten, local_gt_flatten).item():.3e}")
#     avg_diff_per_feat += abs(local_gen_flatten - local_gt_flatten)
#     # print()
# avg_diff_per_feat /= motion_gt.shape[0]
# avg_diff_per_feat_ = [f"{diff.item():.3e}" for diff in avg_diff_per_feat]
# print("Average difference per feature:", avg_diff_per_feat_)


In [None]:
# Per-frame comparison of generated vs. ground-truth motion, plus the mean
# absolute difference of each feature over all compared frames.
avg_diff_per_feat = 0
gen_flatten = generated_motion.reshape(-1, generated_motion.shape[-1])
gt_flatten = motion_gt.reshape(-1, motion_gt.shape[-1])
print(gen_flatten.shape, gt_flatten.shape)
# Only frames present in both sequences can be compared. With custom audio the
# generated sequence need not match the dataset ground truth in length, and the
# old loop raised IndexError past the end of the shorter one.
num_compared = min(gen_flatten.shape[0], gt_flatten.shape[0])
for i in range(num_compared):
    print(i)
    frame_diff = gen_flatten[i] - gt_flatten[i]
    all_diff = [f"{diff.item():.3e}" for diff in frame_diff]
    diff_sq = [float(d) * float(d) for d in all_diff]
    all_diff_sq = [f"{diff:.3e}" for diff in diff_sq]
    print(f"L1 Difference: {all_diff}")
    # print(f"L2 Difference: {all_diff_sq}")
    print(f"Per frame MSE loss: {F.mse_loss(gen_flatten[i], gt_flatten[i]).item():.3e}")
    avg_diff_per_feat += abs(frame_diff)
if num_compared == 0:
    # The old code crashed here on `(0).shape`.
    print("No frames to compare.")
else:
    print(avg_diff_per_feat.shape)
    # Divide by the number of frames actually summed. motion_gt.shape[0] is only
    # that count when motion_gt is already 2-D and no longer than gen_flatten.
    avg_diff_per_feat /= num_compared
    avg_diff_per_feat_ = [f"{diff.item():.3e}" for diff in avg_diff_per_feat]
    print("Average difference per feature:", avg_diff_per_feat_)


In [None]:
# Optional diagnostics: per-frame statistics, an MSE-per-frame plot and an
# animation comparing generated and ground-truth motion latents (saved under output/).
plot_diff = False
if plot_diff:
    # matplotlib is not imported by any earlier cell. Import it here so enabling
    # the flag does not fail with NameError on plt / FuncAnimation / FFMpegWriter.
    import matplotlib.pyplot as plt
    from matplotlib.animation import FFMpegWriter, FuncAnimation

    # Reshape tensors to combine batch and sequence dimensions
    generated_motion_flat = generated_motion.reshape(-1, generated_motion.shape[-1])  # [B*T, feat]
    motion_gt_flat = motion_gt.reshape(-1, motion_gt.shape[-1])  # [B*T, feat]
    start_frame_to_render = 200
    max_frame_to_render = 200

    # Only frames present in both sequences can be compared. The fixed
    # 200..399 window used to raise IndexError on shorter clips.
    total_frames = min(generated_motion_flat.shape[0], motion_gt_flat.shape[0])
    end_frame_to_render = min(start_frame_to_render + max_frame_to_render, total_frames)

    # Batch/sequence labels need the number of frames per chunk. That is only
    # known while generated_motion is still (B, T, feat). After the
    # reshape(-1, 20) cell, shape[1] is the feature size, so the old labels were wrong.
    frames_per_chunk = generated_motion.shape[1] if generated_motion.ndim == 3 else None

    def frame_label(idx):
        """Return 'Frame idx', plus its (batch, seq) position when that is known."""
        if frames_per_chunk is None:
            return f"Frame {idx}"
        return f"Frame {idx} (Batch {idx // frames_per_chunk}, Seq {idx % frames_per_chunk})"

    # Store frame data and calculate global min/max values
    frame_data = []
    global_min_latent = float('inf')
    global_max_latent = float('-inf')
    global_min_diff = float('inf')
    global_max_diff = float('-inf')

    for i in tqdm(range(start_frame_to_render, end_frame_to_render)):
        frame_diff = {
            'frame_idx': i,
            'generated': generated_motion_flat[i].detach().cpu().numpy(),
            'ground_truth': motion_gt_flat[i].detach().cpu().numpy(),
            'difference': (generated_motion_flat[i] - motion_gt_flat[i]).detach().cpu().numpy(),
            'mse_loss': F.mse_loss(generated_motion_flat[i], motion_gt_flat[i]).item()
        }

        # Update global min/max values
        global_min_latent = min(global_min_latent,
                            frame_diff['generated'].min(),
                            frame_diff['ground_truth'].min())
        global_max_latent = max(global_max_latent,
                            frame_diff['generated'].max(),
                            frame_diff['ground_truth'].max())
        global_min_diff = min(global_min_diff, frame_diff['difference'].min())
        global_max_diff = max(global_max_diff, frame_diff['difference'].max())

        frame_data.append(frame_diff)

        # Print statistics for each frame
        print(frame_label(i))
        print("Generated motion:", frame_diff['generated'])
        print(f"Ground truth motion: {frame_diff['ground_truth']}")
        print(f"L1 Difference: {[f'{d:.3e}' for d in frame_diff['difference']]}")
        print(f"MSE loss: {frame_diff['mse_loss']:.3e}")
        print()

    if not frame_data:
        # The old code divided by len(frame_data) == 0 here.
        print(f"No frames to render: start frame {start_frame_to_render} >= {total_frames} comparable frames.")
    else:
        # Calculate and print average difference across all frames
        avg_diff_per_feat = sum(abs(frame['difference']) for frame in frame_data) / len(frame_data)
        print("Average difference per feature:", [f"{diff:.3e}" for diff in avg_diff_per_feat])

        # Add padding to the limits for better visualization
        padding = 0.1
        latent_range = global_max_latent - global_min_latent
        diff_range = global_max_diff - global_min_diff
        global_min_latent -= latent_range * padding
        global_max_latent += latent_range * padding
        global_min_diff -= diff_range * padding
        global_max_diff += diff_range * padding

        # savefig / anim.save do not create missing directories.
        os.makedirs('output', exist_ok=True)

        # Create static plot showing differences across all frames
        plt.figure(figsize=(15, 8))
        frame_indices = range(len(frame_data))
        mse_losses = [frame['mse_loss'] for frame in frame_data]
        plt.plot(frame_indices, mse_losses, label='MSE Loss')
        plt.xlabel('Frame Index')
        plt.ylabel('MSE Loss')
        plt.title('MSE Loss per Frame (All Batches)')
        plt.legend()
        plt.grid(True)
        plt.savefig('output/mse_loss_per_frame.png')
        plt.close()

        # Create animated visualization with fixed axes
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 10))
        fig.suptitle('Motion Latent Comparison (All Batches)')

        def update(frame):
            """Redraw both axes for position `frame` in frame_data."""
            ax1.clear()
            ax2.clear()
            # Label with the absolute frame index, not the position in the window.
            absolute_idx = frame_data[frame]['frame_idx']

            # Plot generated vs ground truth with fixed y-axis
            latent_indices = range(len(frame_data[frame]['generated']))
            ax1.plot(latent_indices, frame_data[frame]['generated'], label='Generated', alpha=0.7)
            ax1.plot(latent_indices, frame_data[frame]['ground_truth'], label='Ground Truth', alpha=0.7)
            ax1.set_title(f'{frame_label(absolute_idx)}: Generated vs Ground Truth')
            ax1.set_xlabel('Latent Index')
            ax1.set_ylabel('Value')
            ax1.set_ylim(global_min_latent, global_max_latent)
            ax1.legend()
            ax1.grid(True)

            # Plot differences with fixed y-axis
            ax2.bar(latent_indices, frame_data[frame]['difference'])
            ax2.set_title(f'Frame {absolute_idx}: Differences (MSE: {frame_data[frame]["mse_loss"]:.3e})')
            ax2.set_xlabel('Latent Index')
            ax2.set_ylabel('Difference')
            ax2.set_ylim(global_min_diff, global_max_diff)
            ax2.grid(True)

            plt.tight_layout()
            return ax1, ax2

        # Create animation
        anim = FuncAnimation(fig, update, frames=len(frame_data), interval=40)

        # Save animation with higher quality
        writer = FFMpegWriter(fps=10, metadata=dict(artist='Me'), bitrate=1800)
        anim.save('output/latent_comparison.mp4', writer=writer)
        plt.close()

In [None]:
# generated_motion[:, :, -5] = motion_gt[:, :, -5]
# generated_motion[:, :, -2] = motion_gt[:, :, -2]


### 4. Render

In [None]:
# Source-side tensors from section 1.2: appearance volume, transformed and 'kp' keypoints.
print(f_s.shape, x_s.shape, x_c_s.shape)
print("Input scale is ", x_s_info['scale'])

In [None]:
# gen_pose = generated_motion[0:1,:, -5:]
# abs_gen_pose = torch.abs(gen_pose)
# # Calculate and print the average of each dimension in gen_pose
# avg_gen_pose = abs_gen_pose.mean(dim=(0, 1))
# print("Average of each dimension in gen_pose:")
# for i, avg in enumerate(avg_gen_pose):
#     print(f"Dimension {i}: {avg.item():.4f}")


In [None]:
# Display the shape of the transformed source keypoints.
x_s.shape

In [None]:
import cv2
import numpy as np
import torch

# Render a generated motion-latent sequence into RGB frames.
#
# NOTE(review): input shapes are inferred from the indexing below; confirm
# against dit_inference. gen_motion_batch must be 2-D (T_gen, D) and
# motion_prev[0] must be (T_prev, D).
#
# Author's notes on the 21 keypoint groups (unverified):
#   mouth: 14, 17, 19, 20 | eye and mouth: 15, 16, 18
#   eye & forehead: 1, 2, 11, 12, 13 (1.z is mouth, 12 small eye)
#   shape: 0, 3, 4, 7, 8, 9, 10 (may include shape-dependent pose) | cloth: 5, 6

def process_motion_batch(gen_motion_batch, motion_prev, f_s, x_s, warp_decode_func,
                         *, use_identity_pose=True, identity_scale=1.5,
                         inference_batch_size=4):
    """Decode generated motion latents into a list of uint8 RGB frames.

    ``motion_prev[0]`` is prepended to ``gen_motion_batch``. The first
    ``len(audio_model_config['latent_mask_1'])`` columns of every row are then
    scattered into a 63-dim (21 x 3) expression offset, at the column indices
    given by that mask. The offset is applied to the canonical keypoints
    ``x_c_s`` under one fixed pose::

        x_d = scale * (x_c_s @ R(pitch, yaw, roll) + exp) + t

    The driving keypoints are decoded in chunks by ``warp_decode_func``.
    Trailing latent columns that earlier drafts treated as pose are not used
    as pose here.

    Args:
        gen_motion_batch: 2-D tensor ``(T_gen, D)`` of generated latents.
        motion_prev: context motion. ``motion_prev[0]`` is rendered before the
            generated rows, so the output has
            ``motion_prev[0].shape[0] + T_gen`` frames.
        f_s: 5-D source feature tensor, expanded along dim 0 per frame.
        x_s: 3-D source keypoint tensor, expanded along dim 0 per frame.
        warp_decode_func: callable ``(f_s, x_s, x_d) -> {'out': tensor}``.
            ``'out'`` is laid out ``(B, C, H, W)``; its values are assumed to be
            in [0, 1] and are clipped to that range.
        use_identity_pose: if True, use zero rotation and translation with
            ``identity_scale``. Otherwise use the pose in the global
            ``x_s_info``.
        identity_scale: keypoint scale used with the identity pose.
        inference_batch_size: number of frames per ``warp_decode_func`` call.

    Returns:
        list of ``(H, W, C)`` ``np.uint8`` arrays.

    Raises:
        ValueError: if ``gen_motion_batch`` is not 2-D.

    Relies on notebook globals: ``device``, ``x_c_s``, ``x_s_info``,
    ``audio_model_config``, ``tqdm`` and ``get_rotation_matrix``.
    """
    if gen_motion_batch.dim() != 2:
        raise ValueError(
            f"gen_motion_batch must be 2-D (T, D), got shape {tuple(gen_motion_batch.shape)}"
        )

    # Context frames first, then the generated ones.
    full_motion = torch.cat([motion_prev[0], gen_motion_batch], dim=0)

    if use_identity_pose:
        t = torch.zeros((1, 3), dtype=torch.float32, device=device)
        pitch = torch.zeros(1, dtype=torch.float32, device=device)
        yaw = torch.zeros(1, dtype=torch.float32, device=device)
        roll = torch.zeros(1, dtype=torch.float32, device=device)
        scale = torch.full((1,), identity_scale, dtype=torch.float32, device=device)
    else:
        t = x_s_info['t']
        pitch = x_s_info['pitch']
        yaw = x_s_info['yaw']
        roll = x_s_info['roll']
        scale = x_s_info['scale']
    # Convert once; torch.tensor() on an existing tensor warns and copies on every call.
    t = torch.as_tensor(t, device=device)

    # Scatter latent column i into expression column latent_mask[i]; unmasked dims stay 0.
    exp_63 = torch.zeros(full_motion.shape[0], 63, device=device)
    for src_col, dst_col in enumerate(audio_model_config['latent_mask_1']):
        exp_63[:, dst_col] = full_motion[:, src_col]

    # The pose is identical for every frame, so compute the rotation once and
    # broadcast over all frames instead of looping frame by frame.
    rotation = get_rotation_matrix(pitch, yaw, roll)
    x_d_batch = scale * (x_c_s @ rotation + exp_63.reshape(-1, 21, 3)) + t

    n_frames = x_d_batch.shape[0]
    f_s_batch = f_s.expand(n_frames, -1, -1, -1, -1)
    x_s_batch = x_s.expand(n_frames, -1, -1)

    frames = []
    with torch.no_grad():
        for start in tqdm(range(0, n_frames, inference_batch_size), desc="Processing batches"):
            end = min(start + inference_batch_size, n_frames)
            out = warp_decode_func(f_s_batch[start:end], x_s_batch[start:end], x_d_batch[start:end])
            # Clip before the uint8 cast. Casting out-of-range floats to uint8 is
            # undefined and typically wraps around, producing speckle artifacts.
            batch = out['out'].detach().clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy()
            frames.extend((batch * 255).astype(np.uint8))

    return frames

# Render the context frames followed by the generated frames.
all_frames = process_motion_batch(
    gen_motion_batch=generated_motion,
    motion_prev=init_motion_prev,
    f_s=f_s,
    x_s=x_s,
    warp_decode_func=warp_decode,
)


In [None]:
# motion_latent_input_local = np.load('/mnt/c/Users/mjh/Downloads/x_212.npy')
# target_to_compare = np.load('/mnt/e/data/diffposetalk_data/TFHP_raw/train_split/live_latent/TH_00208/000.npy')
# to_print = motion_latent_input_local[:75, :]
# to_print_target = target_to_compare[:75, :]
# for i in range(55,56):
#     print(i, "gt", to_print[i])
#     print(i, "input", to_print_target[i])

In [None]:
# motion_latent_processed.shape, motion_gt.shape, motion_latent_input.shape

In [None]:

# Create video: write the rendered RGB frames to a silent MP4, then mux in the
# driving audio with ffmpeg. The silent intermediate is deleted only when muxing succeeds.
import os

output_no_audio_path = 'output/audio_driven_video_no_audio.mp4'
output_video = 'output/audio_driven_video_output.mp4'
fps = 25  # Adjust as needed

if len(all_frames) == 0:
    raise ValueError("all_frames is empty; nothing to write")

# cv2.VideoWriter neither creates missing directories nor raises on failure.
os.makedirs(os.path.dirname(output_video), exist_ok=True)

# Remove outputs left over from a previous run.
for stale_path in (output_no_audio_path, output_video):
    with contextlib.suppress(FileNotFoundError):
        os.remove(stale_path)

height, width = all_frames[0].shape[:2]
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
video = cv2.VideoWriter(output_no_audio_path, fourcc, fps, (width, height))
if not video.isOpened():
    raise RuntimeError(f"Could not open video writer for {output_no_audio_path}")
try:
    for frame in all_frames:
        # Frames are RGB; OpenCV expects BGR.
        video.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
finally:
    video.release()

# Add audio to the video using ffmpeg
input_video = output_no_audio_path
input_audio = used_audio_example  # Use the path to your audio file

ffmpeg_cmd = [
    'ffmpeg',
    '-i', input_video,
    '-i', input_audio,
    '-c:v', 'copy',
    '-c:a', 'aac',
    '-shortest',
    output_video,
]

try:
    subprocess.run(ffmpeg_cmd, check=True)
except FileNotFoundError:
    print(f"ffmpeg not found on PATH; silent video kept at {output_no_audio_path}")
except subprocess.CalledProcessError as e:
    print(f"Error adding audio to video: {e}")
else:
    os.remove(output_no_audio_path)
    print(f"Video with audio saved to {output_video}")