### Full pipeline inference
1. Load the target image to drive
2. Load the audio
3. Run DiT inference on the audio to get the motion latent
4. Drive the target image with the latent

### 1. Load target image

1.1 Prepare full live pipeline

In [None]:
import time

import os
import contextlib
import os.path as osp
import numpy as np
import cv2
import torch
import yaml
import tyro
import subprocess
from rich.progress import track
import torchvision
import cv2
import threading
import queue
import torchvision.transforms as transforms
from concurrent.futures import ThreadPoolExecutor, as_completed
import glob
import os
import numpy as np
import time
import torch
import imageio

from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig
from src.config.crop_config import CropConfig

def partial_fields(target_class, kwargs):
    """Instantiate ``target_class`` from the entries of ``kwargs`` it knows.

    A key is forwarded to the constructor only when
    ``hasattr(target_class, key)`` is true; all other entries are dropped.
    """
    accepted = {}
    for key, value in kwargs.items():
        if hasattr(target_class, key):
            accepted[key] = value
    return target_class(**accepted)

# Build the argument container with no overrides.
args = ArgumentConfig()
# Derive the inference and crop configs from the argument fields they share.
inference_cfg = partial_fields(InferenceConfig, args.__dict__)
crop_cfg = partial_fields(CropConfig, args.__dict__)
# print("inference_cfg: ", inference_cfg)
# print("crop_cfg: ", crop_cfg)
# Device used by the helper functions and model loaders defined below.
device = 'cuda'
print("Compile complete")

'''
Common modules
'''

from src.utils.helper import load_model, concat_feat
from src.utils.camera import headpose_pred_to_degree, get_rotation_matrix
from src.utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
from src.config.inference_config import InferenceConfig
from src.utils.cropper import Cropper
from src.utils.camera import get_rotation_matrix
from src.utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream
from src.utils.crop import _transform_img, prepare_paste_back, paste_back
from src.utils.io import load_image_rgb, load_video, resize_to_limit, dump, load
from src.utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix, is_image
from src.utils.filter import smooth


'''
Util functions
'''

def calculate_distance_ratio(lmk: np.ndarray, idx1: int, idx2: int, idx3: int, idx4: int, eps: float = 1e-6) -> np.ndarray:
    """Return the ratio between two landmark-to-landmark distances.

    The numerator is the distance between landmarks ``idx1`` and ``idx2``; the
    denominator is the distance between ``idx3`` and ``idx4`` plus ``eps`` so
    that a zero-length reference segment does not divide by zero.  Landmarks
    are indexed along axis 1 of ``lmk`` and norms are taken with
    ``keepdims=True``, so the result keeps a trailing axis of length 1.
    """
    numerator = np.linalg.norm(lmk[:, idx1] - lmk[:, idx2], axis=1, keepdims=True)
    denominator = np.linalg.norm(lmk[:, idx3] - lmk[:, idx4], axis=1, keepdims=True) + eps
    return numerator / denominator


def calc_eye_close_ratio(lmk: np.ndarray, target_eye_ratio: np.ndarray = None) -> np.ndarray:
    """Compute the left and right eye distance ratios for each landmark set.

    The two ratios are concatenated along axis 1 (left first, then right).
    When ``target_eye_ratio`` is given it is appended as further column(s).
    """
    columns = [
        calculate_distance_ratio(lmk, 6, 18, 0, 12),    # left eye
        calculate_distance_ratio(lmk, 30, 42, 24, 36),  # right eye
    ]
    if target_eye_ratio is not None:
        columns.append(target_eye_ratio)
    return np.concatenate(columns, axis=1)


def calc_lip_close_ratio(lmk: np.ndarray) -> np.ndarray:
    """Compute the lip distance ratio from landmarks 90/102 over 48/66."""
    lip_close_ratio = calculate_distance_ratio(lmk, 90, 102, 48, 66)
    return lip_close_ratio

def calc_ratio(lmk_lst):
    """Compute per-item eye and lip ratios for a sequence of landmark arrays.

    Each landmark array gets a leading batch axis (``lmk[None]``) before being
    passed to the ratio helpers.  Returns ``(eye_ratios, lip_ratios)``, two
    lists aligned with ``lmk_lst``.
    """
    # One pass over the input; eyes are evaluated before lips for each item.
    ratio_pairs = [
        (calc_eye_close_ratio(lmk[None]), calc_lip_close_ratio(lmk[None]))
        for lmk in lmk_lst
    ]
    input_eye_ratio_lst = [eye for eye, _ in ratio_pairs]
    input_lip_ratio_lst = [lip for _, lip in ratio_pairs]
    return input_eye_ratio_lst, input_lip_ratio_lst

def prepare_videos_(imgs, device):
    """Convert a stack of frames into a normalized tensor on ``device``.

    imgs: NxHxWx3 frames, either a list of arrays or a single ndarray
    device: target device for the returned tensor
    return: Nx3xHxW tensor with values divided by 255 and clamped to 0~1
    raises: ValueError if ``imgs`` is neither a list nor an ndarray
    """
    if not isinstance(imgs, (list, np.ndarray)):
        raise ValueError(f'imgs type error: {type(imgs)}')

    frames = np.array(imgs) if isinstance(imgs, list) else imgs

    # Move the raw frames to the device first and normalize there.
    batch = torch.from_numpy(frames).permute(0, 3, 1, 2).to(device)  # NxHxWx3 -> Nx3xHxW
    return torch.clamp(batch / 255., 0, 1)

def get_kp_info(x: torch.Tensor, **kwargs) -> dict:
    """Run the motion extractor and return the implicit keypoint information.

    x: Bx3xHxW, normalized to 0~1
    flag_refine_info (keyword, default True): whether to transform the pose
        predictions to degrees and reshape 'kp' / 'exp' to BxNx3
    return: a dict containing the keys 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'
    """
    use_half = inference_cfg.flag_use_half_precision
    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_half):
        kp_info = motion_extractor(x)

        if use_half:
            # Cast every tensor entry back to float32.
            for name, value in kp_info.items():
                if isinstance(value, torch.Tensor):
                    kp_info[name] = value.float()

    if not kwargs.get('flag_refine_info', True):
        return kp_info

    batch_size = kp_info['kp'].shape[0]
    for angle in ('pitch', 'yaw', 'roll'):
        kp_info[angle] = headpose_pred_to_degree(kp_info[angle])[:, None]  # Bx1
    for name in ('kp', 'exp'):
        kp_info[name] = kp_info[name].reshape(batch_size, -1, 3)  # BxNx3

    return kp_info

def read_video_frames(video_path):
    """Decode a video into a list of 256x256 RGB frames.

    video_path: path handed to ``cv2.VideoCapture``
    return: ``(video_path, frames)`` where ``frames`` is a list of arrays
        converted from BGR to RGB and resized to 256x256.  At most the number
        of frames reported by the container is read; decoding stops early at
        the first failed read.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        for _ in range(frame_count):
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = cv2.resize(frame, (256, 256))  # Resize to 256x256
            frames.append(frame)
    finally:
        # Release the capture handle even if decoding or resizing raises.
        cap.release()

    return video_path, frames

def read_multiple_videos(video_paths, num_threads=4):
    """Decode several videos concurrently with a thread pool.

    Returns one ``read_video_frames`` result per path, in the same order as
    ``video_paths``.
    """
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        return list(pool.map(read_video_frames, video_paths))

def euler_to_quaternion(pitch, yaw, roll):
    """Convert Euler angles (radians) to quaternions.

    pitch, yaw, roll: tensors of identical shape
    return: tensor with a new trailing axis of size 4, ordered (w, x, y, z)
    """
    half_yaw, half_pitch, half_roll = yaw * 0.5, pitch * 0.5, roll * 0.5
    cy, sy = torch.cos(half_yaw), torch.sin(half_yaw)
    cp, sp = torch.cos(half_pitch), torch.sin(half_pitch)
    cr, sr = torch.cos(half_roll), torch.sin(half_roll)

    components = [
        cr * cp * cy + sr * sp * sy,  # w
        sr * cp * cy - cr * sp * sy,  # x
        cr * sp * cy + sr * cp * sy,  # y
        cr * cp * sy - sr * sp * cy,  # z
    ]
    return torch.stack(components, dim=-1)

def quaternion_to_euler(q):
    """
    Convert quaternions to Euler angles in radians.
    q: torch.Tensor of shape (..., 4) holding quaternions as (w, x, y, z)
    Returns: tuple (pitch, yaw, roll) of torch.Tensor
    """
    w, x, y, z = q[..., 0], q[..., 1], q[..., 2], q[..., 3]

    # Pitch (y-axis rotation); when |sin| reaches 1, use +/- pi/2 directly
    # instead of the asin result.
    sinp = 2 * (w * y - z * x)
    at_limit = torch.abs(sinp) >= 1
    pitch = torch.where(at_limit, torch.sign(sinp) * torch.pi / 2, torch.asin(sinp))

    # Roll (x-axis rotation)
    roll = torch.atan2(2 * (w * x + y * z), 1 - 2 * (x * x + y * y))

    # Yaw (z-axis rotation)
    yaw = torch.atan2(2 * (w * z + x * y), 1 - 2 * (y * y + z * z))

    return pitch, yaw, roll

def quaternion_to_euler_degrees(q):
    """
    Convert quaternions to Euler angles in degrees.
    q: torch.Tensor of shape (..., 4) holding quaternions as (w, x, y, z)
    Returns: tuple (pitch, yaw, roll) of torch.Tensor, in degrees
    """
    angles_rad = quaternion_to_euler(q)
    pitch_deg, yaw_deg, roll_deg = (torch.rad2deg(angle) for angle in angles_rad)
    return pitch_deg, yaw_deg, roll_deg


'''
Loading source related modules
'''
def prepare_source(img: np.ndarray) -> torch.Tensor:
    """Convert a source image (or batch of images) into a model input tensor.

    img: HxWx3 or BxHxWx3 array, uint8
    return: Bx3xHxW float32 tensor normalized to 0~1, on the module-level
        ``device`` (B is 1 for a single image)
    raises: ValueError if ``img`` is not 3- or 4-dimensional
    """
    if img.ndim == 3:
        x = img[np.newaxis]  # HxWx3 -> 1xHxWx3
    elif img.ndim == 4:
        x = img  # already BxHxWx3
    else:
        raise ValueError(f'img ndim should be 3 or 4: {img.ndim}')

    # astype() allocates a new array, so the caller's image is never modified
    # and no explicit copy is needed beforehand.
    x = np.clip(x.astype(np.float32) / 255., 0, 1)  # normalized and clipped to 0~1
    x = torch.from_numpy(x).permute(0, 3, 1, 2)  # BxHxWx3 -> Bx3xHxW
    return x.to(device)

def warp_decode(feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> dict:
    """Warp the source feature volume with the implicit keypoints and decode it.

    feature_3d: Bx32x16x64x64, feature volume
    kp_source: BxNx3
    kp_driving: BxNx3
    return: the dict produced by ``warping_module``, with its 'out' entry
        replaced by the output of ``spade_generator``
    """
    # The line 18 in Algorithm 1: D(W(f_s; x_s, x'_d,i))
    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16,
                                 enabled=inference_cfg.flag_use_half_precision):
        # get decoder input
        ret_dct = warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
        # decode
        ret_dct['out'] = spade_generator(feature=ret_dct['out'])

    return ret_dct

def extract_feature_3d(x: torch.Tensor) -> torch.Tensor:
    """Extract the appearance feature of the image with F.

    x: Bx3xHxW, normalized to 0~1
    return: output of ``appearance_feature_extractor`` cast to float32
    """
    half_precision = inference_cfg.flag_use_half_precision
    autocast_ctx = torch.autocast(device_type='cuda', dtype=torch.float16, enabled=half_precision)
    with torch.no_grad(), autocast_ctx:
        feature_3d = appearance_feature_extractor(x)

    return feature_3d.float()

def transform_keypoint(kp_info: dict):
    """
    Apply pose, expression deformation, scale and shift to the implicit keypoints.
    kp_info: dict with the keys 'kp', 'pitch', 'yaw', 'roll', 't', 'exp', 'scale'
    return: transformed keypoints, BxNx3
    """
    kp = kp_info['kp']    # (bs, k, 3)
    pitch = headpose_pred_to_degree(kp_info['pitch'])
    yaw = headpose_pred_to_degree(kp_info['yaw'])
    roll = headpose_pred_to_degree(kp_info['roll'])
    t, exp, scale = kp_info['t'], kp_info['exp'], kp_info['scale']

    bs = kp.shape[0]
    # 'kp' may be flattened as Bx(num_kp*3) or already shaped as Bxnum_kpx3
    num_kp = kp.shape[1] // 3 if kp.ndim == 2 else kp.shape[1]

    rot_mat = get_rotation_matrix(pitch, yaw, roll)    # (bs, 3, 3)

    # Eqn.2: s * (R * x_c,s + exp) + t
    rotated = kp.view(bs, num_kp, 3) @ rot_mat
    kp_transformed = rotated + exp.view(bs, num_kp, 3)
    kp_transformed *= scale[..., None]  # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3)
    kp_transformed[:, :, 0:2] += t[:, None, 0:2]  # only tx and ty are applied, z is left out

    return kp_transformed

def parse_output(out: torch.Tensor) -> np.ndarray:
    """Convert a generator output tensor into uint8 image arrays.

    out: Bx3xHxW tensor
    return: BxHxWx3, uint8, values clipped to 0~1 and scaled to 0~255
    """
    frames = out.data.cpu().numpy().transpose(0, 2, 3, 1)  # Bx3xHxW -> BxHxWx3
    frames = np.clip(frames, 0, 1) * 255  # 0~1 -> 0~255
    return np.clip(frames, 0, 255).astype(np.uint8)
'''
Main module for inference
'''
# Parse the model configuration; the context manager closes the file handle
# as soon as parsing finishes instead of leaving it open.
with open(inference_cfg.models_config, 'r') as config_file:
    model_config = yaml.load(config_file, Loader=yaml.SafeLoader)
# init F
appearance_feature_extractor = load_model(inference_cfg.checkpoint_F, model_config, device, 'appearance_feature_extractor')
# init M
motion_extractor = load_model(inference_cfg.checkpoint_M, model_config, device, 'motion_extractor')
# init W
warping_module = load_model(inference_cfg.checkpoint_W, model_config, device, 'warping_module')
# init G
spade_generator = load_model(inference_cfg.checkpoint_G, model_config, device, 'spade_generator')
# init S and R (optional: skipped when no checkpoint path is set or the file is missing)
if inference_cfg.checkpoint_S is not None and os.path.exists(inference_cfg.checkpoint_S):
    stitching_retargeting_module = load_model(inference_cfg.checkpoint_S, model_config, device, 'stitching_retargeting_module')
else:
    stitching_retargeting_module = None

# Landmark / crop helper used by the image-loading cell.
cropper = Cropper(crop_cfg=crop_cfg, device=device)


1.2 Load single image

In [None]:
# Path of the source image that will be animated.
input_path = '/mnt/c/Users/mjh/Downloads/live_in/t1.jpg'
img_rgb = load_image_rgb(input_path)
source_rgb_lst = [img_rgb]

# Landmarks are computed on the image as loaded; no cropping happens in this cell.
source_lmk = cropper.calc_lmk_from_cropped_image(source_rgb_lst[0])
img_crop_256x256 = cv2.resize(source_rgb_lst[0], (256, 256))  # force to resize to 256x256

# Source-side tensors reused by the rendering cells:
#   I_s      - normalized input tensor
#   x_s_info - keypoint info dict from the motion extractor
#   x_c_s    - the 'kp' entry of x_s_info
#   x_s      - keypoints after transform_keypoint
#   f_s      - appearance feature volume
I_s = prepare_source(img_crop_256x256)
x_s_info = get_kp_info(I_s)
x_c_s = x_s_info['kp']
x_s = transform_keypoint(x_s_info)
f_s = extract_feature_3d(I_s)


### 2 Load audio

2.1 Prepare audio pipeline

In [None]:
import json
import torch
import torchaudio
from transformers import Wav2Vec2Model, Wav2Vec2Processor
import os
import numpy as np
from typing import List, Tuple
import torch.multiprocessing as mp
import torch.distributed as dist
import torch.nn.functional as F
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

# Constants
MODEL_NAME = "facebook/wav2vec2-base-960h"
TARGET_SAMPLE_RATE = 16000  # audio is resampled to this rate before encoding
FRAME_RATE = 25  # frame rate the encoder features are interpolated to
SECTION_LENGTH = 3  # section length in seconds (see load_and_process_audio)
OVERLAP = 10  # overlap between consecutive sections, expressed in frames at FRAME_RATE

# Path components joined in read_json_and_form_paths
DB_ROOT = 'vox2-audio-tx'
LOG = 'log'
AUDIO = 'audio/audio'
OUTPUT_DIR = 'audio_encoder_output'
BATCH_SIZE = 16  # mini-batch size used by process_wav_file


# Move model and processor to global scope
# NOTE: this rebinds the global `device` (previously the string 'cuda') to a torch.device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
wav2vec_model = Wav2Vec2Model.from_pretrained(MODEL_NAME).to(device)
wav2vec_processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)

def read_multiple_audios(paths, num_threads=12):
    """Load and segment several audio files concurrently with a thread pool.

    Returns one ``load_and_process_audio`` result per path, in the same order
    as ``paths``.
    """
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        return list(pool.map(load_and_process_audio, paths))


def read_json_and_form_paths(data,id_key):
    """Build input wav paths and output file stems from a nested mapping.

    data: mapping of ``{id: {url: {clip_id: ...}}}``; only the keys are used
    id_key: unused, kept so existing callers keep working.  (The loop used to
        reuse this name as its loop variable, silently overwriting the
        argument; the loop now has its own variable.)
    return: ``(file_paths, filenames)`` - two parallel lists.  ``file_paths``
        point below ``DB_ROOT/AUDIO`` with '.txt' in the clip id replaced by
        '.wav'; ``filenames`` point below ``DB_ROOT/OUTPUT_DIR`` and are
        formed as '<url>+<clip id without .txt>'.

    Side effect: creates the output directory for every top-level id.
    """
    filenames = []
    file_paths = []

    # Iterate through the nested structure
    for identity_key, urls in data.items():
        os.makedirs(os.path.join(DB_ROOT, OUTPUT_DIR, identity_key), exist_ok=True)
        for url_key, clips in urls.items():
            for clip_id in clips.keys():
                # Form the input path and the matching output stem
                file_paths.append(
                    os.path.join(DB_ROOT, AUDIO, identity_key, url_key, clip_id.replace('.txt', '.wav')))
                filenames.append(
                    os.path.join(DB_ROOT, OUTPUT_DIR, identity_key, url_key + '+' + clip_id.replace('.txt', '')))

    return file_paths, filenames

def load_and_process_audio(file_path):
    """Load an audio file and cut it into fixed-length, overlapping sections.

    The waveform is resampled to ``TARGET_SAMPLE_RATE`` if needed and averaged
    to mono when it has more than one channel.  The final (or only) section is
    zero-padded on the right to the full section length.

    file_path: path handed to ``torchaudio.load``
    return: ``(file_path, sections)`` where ``sections`` is a list of 1-D
        tensors, each ``SECTION_LENGTH * 16027`` samples long.  Both the short
        and the normal case return this shape; the short case used to return
        ``([waveform], sample_rate)``, which broke callers that unpack
        ``(path, sections)``.
    """
    waveform, sample_rate = torchaudio.load(file_path)

    if sample_rate != TARGET_SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, sample_rate, TARGET_SAMPLE_RATE)

    # Convert to mono if stereo
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    print(file_path," waveform.shape ",waveform.shape)

    # Calculate section length and overlap in samples
    # NOTE(review): 16027 (not TARGET_SAMPLE_RATE) is kept from the original;
    # presumably tuned to the encoder's frame count per section - confirm
    # before changing.
    section_samples = SECTION_LENGTH * 16027
    overlap_samples = int(OVERLAP / FRAME_RATE * TARGET_SAMPLE_RATE)
    print('section_samples',section_samples,'overlap_samples',overlap_samples)

    total_samples = waveform.shape[1]

    # Pad if shorter than one section
    if total_samples < section_samples:
        waveform = torch.nn.functional.pad(waveform, (0, section_samples - total_samples))
        return file_path, [waveform.squeeze(0)]

    # Split into sections with overlap
    sections = []
    start = 0

    print('starting to segment', file_path)
    while start < total_samples:
        end = start + section_samples
        chunk = waveform[:, start:end]  # slicing clips at the end of the waveform
        if end >= total_samples:
            # Last section: pad on the right to the full length, then stop.
            chunk = torch.nn.functional.pad(chunk, (0, section_samples - chunk.shape[1]))
            sections.append(chunk.squeeze(0))
            print(chunk.shape)
            break

        sections.append(chunk.squeeze(0))
        # The next section starts `overlap_samples` before this one ends.
        start = end - overlap_samples

    return file_path, sections

def inference_process_wav_file(path):
    """Encode one wav file with the global Wav2Vec2 model.

    The file is split into sections by ``load_and_process_audio``, all sections
    are encoded as one batch, and the hidden states are linearly interpolated
    along the time axis to ``FRAME_RATE / 50`` of their original length.

    return: tensor of interpolated hidden states, one row per section
    """
    _, segments = load_and_process_audio(path)
    segment_array = np.array(segments)

    encoded = wav2vec_processor(segment_array, sampling_rate=TARGET_SAMPLE_RATE, return_tensors="pt", padding=True)
    inputs = encoded.input_values.to(device)

    with torch.no_grad():
        latent = wav2vec_model(inputs).last_hidden_state
        new_seq_length = int(latent.shape[1] * (FRAME_RATE / 50))

        # interpolate() works on the last axis, so swap time and feature axes
        # around the call.
        resampled = F.interpolate(latent.transpose(1, 2),
                                  size=new_seq_length,
                                  mode='linear',
                                  align_corners=True)
        latent_features_interpolated = resampled.transpose(1, 2)
    return latent_features_interpolated


def process_wav_file(paths, output_paths, uid):
    """Encode a list of wav files with Wav2Vec2 and save one .npy per file.

    paths: audio file paths, loaded and segmented via ``read_multiple_audios``
    output_paths: output file names passed to ``np.save``, parallel to ``paths``
    uid: only used in the progress-bar description

    All segments of all files are encoded in mini-batches of ``BATCH_SIZE``;
    the interpolated features are queued on the device and written out file
    by file once enough segments for the next file are available.
    """
    device = torch.device(f"cuda")

    # A separate model/processor instance is loaded here instead of using the
    # module-level wav2vec_model / wav2vec_processor.
    model = Wav2Vec2Model.from_pretrained(MODEL_NAME).to(device)
    processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)

    read_2_gpu_batch_size = 2048  # number of segments handled per outer iteration
    gpu_batch_size = BATCH_SIZE  # number of segments per forward pass
    process_queue = torch.Tensor().to(device)  # encoded segments not yet written out

    audio_segments = read_multiple_audios(paths, num_threads=4)
    all_segments = []
    total_segments = 0

    audio_lengths = []  # number of segments per audio file, in input order
    output_fns = []

    # Flatten the per-file segment lists and remember how many belong to each file.
    for (audio_path, segments), output_fn in zip(audio_segments,output_paths):
        all_segments.extend(segments)
        segment_count = len(segments)
        total_segments += segment_count
        audio_lengths.append(segment_count)

        output_fns.append(output_fn)

    all_segments = np.array(all_segments)
    print(all_segments.size)

    read_data_2_gpu_pointer = 0  # index of the next segment to process
    pbar = tqdm(total=total_segments, desc=f"Processing {uid}")

    while read_data_2_gpu_pointer < total_segments:
        current_batch_size = min(read_2_gpu_batch_size, total_segments - read_data_2_gpu_pointer)

        batch_input = all_segments[read_data_2_gpu_pointer:read_data_2_gpu_pointer + current_batch_size]

        mini_batch_start = 0
        all_info = []
        while mini_batch_start < batch_input.shape[0]:
            mini_batch_end = min(mini_batch_start + gpu_batch_size, batch_input.shape[0])
            mini_batch = batch_input[mini_batch_start:mini_batch_end]

            inputs = processor(mini_batch, sampling_rate=TARGET_SAMPLE_RATE, return_tensors="pt", padding=True).input_values.to(device)

            with torch.no_grad():
                outputs = model(inputs)

            latent = outputs.last_hidden_state
            print('latent',latent.shape)
            seq_length = latent.shape[1]
            new_seq_length = int(seq_length * (FRAME_RATE / 50))  # Assuming Wav2Vec2 outputs at ~50Hz

            # Interpolate along the time axis (moved to the last axis for F.interpolate).
            latent_features_interpolated = F.interpolate(latent.transpose(1,2),
                                                            size=new_seq_length,
                                                            mode='linear',
                                                            align_corners=True).transpose(1,2)
            print('latent_features_interpolated',latent_features_interpolated.shape)
            all_info.append(latent_features_interpolated)

            mini_batch_start = mini_batch_end
        all_info_tensor = torch.cat(all_info, dim=0)

        process_queue = torch.cat((process_queue, all_info_tensor), dim=0)

        print(audio_lengths)
        # Write out every file whose segments are all available at the head of the queue.
        while len(output_fns) > 0 and len(process_queue) >= audio_lengths[0]:
            current_output_fn = output_fns[0]
            current_segment_count = audio_lengths[0]

            audio_tensor = process_queue[:current_segment_count]
            np.save(current_output_fn, audio_tensor.cpu().numpy())
            print('save',current_output_fn)
            process_queue = process_queue[current_segment_count:]
            output_fns.pop(0)
            audio_lengths.pop(0)


        read_data_2_gpu_pointer += current_batch_size
        pbar.update(current_batch_size)

    pbar.close()


2.2 Load pretrained latents

In [None]:
def load_and_process_pair(audio_file, motion_file):
    """Load an audio-latent / motion-latent pair and window the motion data.

    The motion array (frames x 133) is zero-padded so that everything after the
    first 10 frames splits into whole chunks of 65 frames.  Each chunk is
    prefixed with the 10 frames that precede it (the first 10 frames of the
    file for the first chunk, the tail of the previous chunk otherwise),
    giving windows of 75 frames.

    return: ``(audio_data, final_windows)``, both truncated along axis 0 to the
        smaller of the two leading sizes
    """
    audio_data = np.load(audio_file)
    motion_data = np.load(motion_file)

    # Pad so the part after the first 10 frames is a multiple of 65 frames.
    pad_length = (65 - (motion_data.shape[0] - 10) % 65) % 65
    padded = np.pad(motion_data, ((0, pad_length), (0, 0)), mode='constant')

    head, body = padded[:10], padded[10:]
    num_windows = body.shape[0] // 65
    windows = body[:num_windows * 65].reshape(num_windows, 65, 133)

    # 10-frame context per window: file head for the first window, then the
    # last 10 frames of each preceding window.
    carry_over = windows[:-1, -10:, :]
    context = np.concatenate([head[None, :, :], carry_over], axis=0)
    final_windows = np.concatenate([context, windows], axis=1)

    # Audio and motion may differ in window count; only that mismatch is
    # fixed here, by truncating both to the shorter one.
    count = min(audio_data.shape[0], final_windows.shape[0])
    return audio_data[:count], final_windows[:count]

2.3 Single input inference

In [None]:
# Parallel lists of example inputs: entry i of "motion", "audio" and "wav"
# is selected together by `index` in the next cell.
input_dict = {
    "motion":[
        '/mnt/e/data/live_latent/motion_latent/id00078/P0OU4bFhwCI+00227.npy',
        '/mnt/e/data/live_latent/motion_latent/id00019/anX4gftNLoc+00140.npy',
        '/mnt/e/data/live_latent/motion_latent/id00012/aE4Om0EEiuk+00117.npy',
        '/mnt/e/data/live_latent/motion_latent/id09263/ARVWnF_NcCI+00002.npy',
        # HDTF
        '/mnt/e/data/diffposetalk_data/TFHP_raw/test_split/live_latent/TH_00226/001.npy',
        '/mnt/e/data/diffposetalk_data/TFHP_raw/train_split/live_latent/TH_00028/000.npy',
        '/mnt/e/data/diffposetalk_data/TFHP_raw/train_split/live_latent/TH_00212/000.npy',

    ],
    "audio":[
        '/mnt/e/data/live_latent/audio_latent/id00078/P0OU4bFhwCI+00227.npy',
        '/mnt/e/data/live_latent/audio_latent/id00019/anX4gftNLoc+00140.npy',
        '/mnt/e/data/live_latent/audio_latent/id00012/aE4Om0EEiuk+00117.npy',
        '/mnt/e/data/live_latent/audio_latent/id09263/ARVWnF_NcCI+00002.npy',
        # HDTF
        '/mnt/e/data/diffposetalk_data/TFHP_raw/test_split/audio_latent/TH_00226/001.npy',
        '/mnt/e/data/diffposetalk_data/TFHP_raw/train_split/audio_latent/TH_00028/000.npy',
        '/mnt/e/data/diffposetalk_data/TFHP_raw/train_split/audio_latent/TH_00212/000.npy',
    ],
    "wav":[
        '/mnt/c/Users/mjh/Downloads/audio_id00078_P0OU4bFhwCI_00227.wav',
        '/mnt/c/Users/mjh/Downloads/audio_id00019_anX4gftNLoc_00140.wav',
        '/mnt/c/Users/mjh/Downloads/audio_id00012_aE4Om0EEiuk_00117.wav',
        '/mnt/c/Users/mjh/Downloads/audio_id09263_ARVWnF_NcCI_00002.wav',
        # HDTF
        '/mnt/e/data/diffposetalk_data/TFHP_raw/audio/TH_00226/001.wav',
        '/mnt/e/data/diffposetalk_data/TFHP_raw/audio/TH_00028/000.wav',
        '/mnt/e/data/diffposetalk_data/TFHP_raw/audio/TH_00212/000.wav',
    ]
}

In [None]:
import numpy as np
# Select one example triple (motion latent, audio latent, wav) from input_dict.
index = 6
motion_latent_npy = input_dict['motion'][index]
audio_latent_path = input_dict['audio'][index]
training_audio_example = input_dict['wav'][index]
# Load the precomputed latents and window the motion data to match the audio.
audio_latent, motion_latent_input = load_and_process_pair(audio_latent_path, motion_latent_npy)


# valid_audio_example = '/mnt/c/Users/mjh/Downloads/l8.wav'
used_audio_example = training_audio_example

# Alternative (disabled): compute the audio latent from the wav file instead
# of loading the precomputed one.
# audio_latent = inference_process_wav_file(used_audio_example)
# audio_latent = np.load(audio_latent_path)
# audio_latent = torch.from_numpy(audio_latent).to(device)
# audio_latent = audio_latent[0].unsqueeze(0)
# Move both latents to the device as tensors.
audio_latent = torch.from_numpy(audio_latent).to(device)
motion_latent_input = torch.from_numpy(motion_latent_input).to(device)
print(audio_latent.shape, motion_latent_input.shape)

### 3 Load DiT model

3.0 decide DiT type

In [None]:
from audio_dit.inference import example_inference
from audio_dit.dataset import process_motion_tensor

In [None]:
audio_latent_input = audio_latent
motion_latent_input = motion_latent_input  # no-op self-assignment, kept as in the original
# Indices 0..20, passed to process_motion_tensor as latent_mask_1.
mask_1 = [i for i in range(21)]
motion_latent_processed, audio_latent_input, shape_in = process_motion_tensor(motion_latent_input, audio_latent_input, latent_type='exp', latent_mask_1=mask_1)
weight_path = 'audio_dit/output/checkpoint_epoch_8000_vanilla_exp_1/model.pth'
config_path = 'audio_dit/output/config.json'
# Split along axis 1: the first 10 steps are the "previous" context, the rest the sequence.
audio_seq = audio_latent_input[:, 10:, :]
audio_prev = audio_latent_input[:, :10, :]
# NOTE: overwrites the shape_in returned by process_motion_tensor with the
# last step of the processed motion tensor.
shape_in = motion_latent_processed[:, -1, :]
motion_prev = motion_latent_processed[:, :10, :]
motion_gt = motion_latent_processed[:, 10:-1, :]  # everything between the context and the last step

# Print shapes of input tensors
print("Audio input shape:", audio_seq.shape)
print("Audio previous shape:", audio_prev.shape)
print("Shape input shape:", shape_in.shape)
print("Motion previous shape:", motion_prev.shape)
print("Motion ground truth shape:", motion_gt.shape)


In [None]:
# Sanity check: does the DiT config file exist at the relative path?
print(os.path.exists(config_path))

3.1 Prepare DiT Model

In [None]:
# generated_motion = example_inference(config_path, weight_path, audio_seq, shape_in, motion_prev, audio_prev,
#                                      total_denoising_steps=5)
# loss = F.mse_loss(generated_motion, motion_gt)
# print(f"MSE Loss between generated motion and ground truth: {loss.item()}")

3.2 check shape

In [None]:
# generated_motion.shape, motion_prev.shape, motion_gt.shape

In [None]:
# if generated_motion.shape[1] != motion_gt.shape[1]:
#     generated_motion = torch.cat([motion_prev, generated_motion], dim=1)

In [None]:
# Calculate MSE loss between generated_motion and motion_latent_input


In [None]:
# diff = generated_motion - motion_gt
# # Create masks for differences and ground truth with absolute values >= 1e-3
# diff_mask = torch.abs(diff) >= 1e-3
# gt_mask = torch.abs(motion_gt) < 1e-5
# total_elements = torch.numel(diff)

# # Count overlaps and non-overlaps
# overlap_count = torch.sum(diff_mask & gt_mask).item()
# diff_only_count = torch.sum(diff_mask & ~gt_mask).item()
# gt_only_count = torch.sum(~diff_mask & gt_mask).item()

# # Apply the mask to the diff tensor
# large_diffs = diff * diff_mask

# # Add the large differences to the generated_motion
# generated_motion = generated_motion - large_diffs * 0.8

# print(f"Number of large differences (>= 1e-3) adjusted: {torch.sum(diff_mask).item()} out of {total_elements}")
# print(f"Number of large values (>= 1e-3) in ground truth: {torch.sum(gt_mask).item()} out of {total_elements}")
# print(f"Number of overlapping large values: {overlap_count}")
# print(f"Number of large differences not in ground truth: {diff_only_count}")
# print(f"Number of large ground truth values not in differences: {gt_only_count}")
# print(f"Max difference after adjustment: {torch.max(torch.abs(generated_motion - motion_gt)).item():.6f}")
# print(f"Min difference after adjustment: {torch.min(torch.abs(generated_motion - motion_gt)).item():.6f}")


In [None]:
# Per-frame comparison of the first batch item against the ground truth.
# NOTE(review): `generated_motion` is only assigned in the commented-out
# example_inference cell above; that cell must be re-enabled and run first.
for i in range(motion_gt.shape[1]):
    print(f"Frame {i}")
    print("Generated motion:", generated_motion[0, i])
    print("Ground truth motion:", motion_gt[0, i])
    print("Difference:", generated_motion[0, i] - motion_gt[0, i])
    print("Per frame MSE loss:", F.mse_loss(generated_motion[0, i], motion_gt[0, i]).item())
    print()

### 4. Do Render

In [None]:
# Inspect which entries the source keypoint info contains.
x_s_info.keys()

In [None]:
# Show the shapes of the cached source tensors and the predicted source scale.
print(f_s.shape, x_s.shape, x_c_s.shape)
print("Input scale is ", x_s_info['scale'])


# Recompute the source tensors (same steps as in the image-loading cell).
I_s = prepare_source(img_crop_256x256)
x_s_info = get_kp_info(I_s)
x_c_s = x_s_info['kp']
x_s = transform_keypoint(x_s_info)
f_s = extract_feature_3d(I_s)

In [None]:
import cv2
import numpy as np
import torch

# Pose, translation and scale are taken from the source image, so only the
# expression changes in the rendered frame.
t = x_s_info['t']
pitch = x_s_info['pitch']
yaw = x_s_info['yaw']
roll = x_s_info['roll']
# scale = motion_latent_input_local[j, 132:133]
scale = x_s_info['scale']

# Extract values from motion
exp_identity = torch.zeros_like(x_s_info['exp'])
# Expression of step 11 of the first batch item, reshaped to 21 keypoints x 3.
# NOTE(review): `generated_motion` is only assigned in the commented-out
# example_inference cell; that cell must be run first.
exp_gen = generated_motion[0, 11, :].reshape(21, 3)
# exp = x_s_info['exp']
exp = exp_gen

# Driving keypoints: scale * (canonical keypoints @ rotation + expression) + translation
x_d_i = scale * (x_c_s @ get_rotation_matrix(pitch, yaw, roll) + exp) + t

# Combine into x_d_i
# x_d_i = motions[i, j].unsqueeze(0).reshape(-1, 21, 3)

# Generate frame
out = warp_decode(f_s, x_s, x_d_i)

# Convert to numpy array
# NOTE: despite its name, `frame_index` holds the rendered HxWx3 image.
# NOTE(review): unlike parse_output, the values are not clipped before the uint8 cast.
frame_index = (out['out'][0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)


In [None]:
# Inspect the shape of the generated motion tensor.
generated_motion.shape

In [None]:

# All-zero expression with the source expression's shape, size-1 axes removed.
exp_identity = torch.zeros_like(x_s_info['exp']).squeeze()
exp_identity

In [None]:
# Flatten all leading axes so each row holds 63 values.
# NOTE(review): this rebinds motion_gt; the plotting cell below additionally
# applies `motion_gt[0]`, so the result depends on which cells were run - confirm the intended order.
motion_gt = motion_gt.reshape(-1, 63)

In [None]:
# Inspect the current shape of motion_gt.
motion_gt.shape

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import ipywidgets as widgets
from IPython.display import display
import numpy as np
import torch
# NOTE(review): not idempotent - each execution of this cell strips another
# leading axis from motion_gt.  The indexing below expects motion_gt to be
# (frames, 63) after this line.
motion_gt = motion_gt[0]
# Calculate the global min and max for each axis
all_keypoints = np.array([motion_gt[frame].reshape(21, 3).squeeze().cpu().numpy() for frame in range(0, motion_gt.shape[0])])
global_min = all_keypoints.min(axis=(0, 1))
global_max = all_keypoints.max(axis=(0, 1))
# All-zero expression used as the base in update_plot.
exp_identity = torch.zeros_like(x_s_info['exp'])
# Generate frame
# Pose, translation and scale come from the source image.
t = x_s_info['t']
pitch = x_s_info['pitch']
yaw = x_s_info['yaw']
roll = x_s_info['roll']
scale = x_s_info['scale']

# Create a function to update the plot
def update_plot(frame_index, elev, azim, **point_selections):
    """Render one frame with only the selected expression keypoints applied.

    frame_index: index into the global ``all_keypoints``
    elev, azim: view angles of the 3D scatter plot
    point_selections: keyword flags named ``point_<i>``; for every true flag
        the expression of keypoint ``i`` is copied from the selected frame,
        all other keypoints keep a zero expression

    Reads the globals set up in this cell (``all_keypoints``, ``exp_identity``,
    ``scale``, ``pitch``, ``yaw``, ``roll``, ``t``, ``x_c_s``, ``x_s``,
    ``f_s``, ``global_min``, ``global_max``).  Shows a figure with the 3D
    keypoints on the left and the rendered frame on the right, and returns it.
    """
    # Extract keypoints from all_keypoints
    keypoints = all_keypoints[frame_index]

    # Set exp to identity plus the currently selected points
    exp = exp_identity.clone()
    # Keyword names look like 'point_<i>'; keep the indices of the checked ones.
    selected_points = [int(point.split('_')[1]) for point, selected in point_selections.items() if selected]

    for point in selected_points:
        exp[0][point] = torch.tensor(keypoints[point])

    # Driving keypoints: scale * (canonical keypoints @ rotation + expression) + translation
    x_d_i = scale * (x_c_s @ get_rotation_matrix(pitch, yaw, roll) + exp) + t

    out = warp_decode(f_s, x_s, x_d_i)

    # Convert to numpy array
    frame_img = (out['out'][0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)

    fig = plt.figure(figsize=(20, 10))

    # 3D plot
    ax1 = fig.add_subplot(121, projection='3d')

    # Plot the keypoints in 3D
    ax1.scatter(keypoints[:, 0], keypoints[:, 1], keypoints[:, 2], c='r', s=20)

    # Highlight selected points
    if selected_points:
        selected_keypoints = keypoints[selected_points]
        ax1.scatter(selected_keypoints[:, 0], selected_keypoints[:, 1], selected_keypoints[:, 2], c='b', s=40)

    # Add labels to each keypoint
    for i, (x, y, z) in enumerate(keypoints):
        ax1.text(x, y, z, str(i), fontsize=8)

    # Set labels for each axis
    ax1.set_xlabel('X')
    ax1.set_ylabel('Y')
    ax1.set_zlabel('Z')

    # Set title
    ax1.set_title(f'3D Keypoints - Frame {frame_index}')

    # Set fixed axis limits
    # (computed over all frames so the axes do not rescale between frames)
    ax1.set_xlim(global_min[0], global_max[0])
    ax1.set_ylim(global_min[1], global_max[1])
    ax1.set_zlim(global_min[2], global_max[2])

    # Set the view angle
    ax1.view_init(elev=elev, azim=azim)

    # Generated frame
    ax2 = fig.add_subplot(122)
    ax2.imshow(frame_img)
    ax2.axis('off')
    ax2.set_title('Generated Frame')

    plt.tight_layout()
    plt.show()
    return fig

# Create sliders for frame selection and view angle
# The frame slider covers every row of motion_gt.
frame_slider = widgets.IntSlider(value=0, min=0, max=motion_gt.shape[0]-1, step=1, description='Frame:')
elev_slider = widgets.FloatSlider(value=20, min=0, max=90, step=1, description='Elevation:')
azim_slider = widgets.FloatSlider(value=45, min=-180, max=180, step=1, description='Azimuth:')

# Create checkboxes for point selection
point_checkboxes = [widgets.Checkbox(value=False, description=f'Point {i}') for i in range(21)]

# Create the interactive plot
# Checkboxes are passed as keywords 'point_<i>', the naming update_plot parses.
interactive_plot = widgets.interactive(update_plot,
                                       frame_index=frame_slider,
                                       elev=elev_slider,
                                       azim=azim_slider,
                                       **{f'point_{i}': checkbox for i, checkbox in enumerate(point_checkboxes)})

# Display the interactive plot
display(widgets.VBox([interactive_plot, widgets.HBox(point_checkboxes)]))

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import ipywidgets as widgets
from IPython.display import display
import numpy as np
import torch

# Calculate the global min and max for each axis
# The indexing expects motion_gt to be (frames, 63) at this point.
all_keypoints = np.array([motion_gt[frame, :].reshape(21, 3).squeeze().cpu().numpy() for frame in range(0, motion_gt.shape[0])])
global_min = all_keypoints.min(axis=(0, 1))
global_max = all_keypoints.max(axis=(0, 1))
# All-zero expression used as the base in update_plotNew.
exp_identity = torch.zeros_like(x_s_info['exp'])
# Generate frame
# Pose, translation and scale come from the source image.
t = x_s_info['t']
pitch = x_s_info['pitch']
yaw = x_s_info['yaw']
roll = x_s_info['roll']
scale = x_s_info['scale']

# Create a function to update the plot
def update_plotNew(frame_index, elev, azim, **dim_values):
    """Render a frame whose expression is set dimension by dimension.

    frame_index: index into the global ``all_keypoints`` (used for the 3D plot only)
    elev, azim: view angles of the 3D scatter plot
    dim_values: keywords named ``dim_<i>``; every value that is not None is
        written to position ``i`` of the flattened (63-value) expression,
        all other positions stay zero

    Reads the globals set up in this cell (``all_keypoints``, ``exp_identity``,
    ``scale``, ``pitch``, ``yaw``, ``roll``, ``t``, ``x_c_s``, ``x_s``,
    ``f_s``, ``global_min``, ``global_max``).  Shows a figure with the 3D
    keypoints on the left and the rendered frame on the right, and returns it.
    """
    # Extract keypoints from all_keypoints
    keypoints = all_keypoints[frame_index]

    # Set exp to identity plus the currently selected dimensions
    exp = exp_identity.clone()
    exp = exp.reshape(-1, 63)
    for dim, value in dim_values.items():
        if value is not None:
            dim_index = int(dim.split('_')[1])
            exp[0][dim_index] = torch.tensor(value)
    exp = exp.reshape(-1, 21, 3)
    # Driving keypoints: scale * (canonical keypoints @ rotation + expression) + translation
    x_d_i = scale * (x_c_s @ get_rotation_matrix(pitch, yaw, roll) + exp) + t

    out = warp_decode(f_s, x_s, x_d_i)

    # Convert to numpy array
    frame_img = (out['out'][0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)

    fig = plt.figure(figsize=(20, 10))

    # 3D plot
    ax1 = fig.add_subplot(121, projection='3d')

    # Plot the keypoints in 3D
    ax1.scatter(keypoints[:, 0], keypoints[:, 1], keypoints[:, 2], c='r', s=20)

    # Highlight modified dimensions
    modified_dims = [int(dim.split('_')[1]) for dim, value in dim_values.items() if value is not None]
    if modified_dims:
        # Flattened dimension d belongs to keypoint d // 3.
        modified_keypoints = keypoints[np.array(modified_dims) // 3]
        ax1.scatter(modified_keypoints[:, 0], modified_keypoints[:, 1], modified_keypoints[:, 2], c='b', s=40)

    # Add labels to each keypoint
    for i, (x, y, z) in enumerate(keypoints):
        ax1.text(x, y, z, str(i), fontsize=8)

    # Set labels for each axis
    ax1.set_xlabel('X')
    ax1.set_ylabel('Y')
    ax1.set_zlabel('Z')

    # Set title
    ax1.set_title(f'3D Keypoints - Frame {frame_index}')

    # Set fixed axis limits
    # (computed over all frames so the axes do not rescale between frames)
    ax1.set_xlim(global_min[0], global_max[0])
    ax1.set_ylim(global_min[1], global_max[1])
    ax1.set_zlim(global_min[2], global_max[2])

    # Set the view angle
    ax1.view_init(elev=elev, azim=azim)

    # Generated frame
    ax2 = fig.add_subplot(122)
    ax2.imshow(frame_img)
    ax2.axis('off')
    ax2.set_title('Generated Frame')

    plt.tight_layout()
    plt.show()
    return fig

# Create sliders for frame selection and view angle
# The frame slider covers every row of motion_gt.
frame_slider = widgets.IntSlider(value=0, min=0, max=motion_gt.shape[0]-1, step=1, description='Frame:')
elev_slider = widgets.FloatSlider(value=20, min=0, max=90, step=1, description='Elevation:')
azim_slider = widgets.FloatSlider(value=45, min=-180, max=180, step=1, description='Azimuth:')

# Create input boxes for each dimension
# NOTE(review): value=None is passed to FloatText; confirm how the widget
# treats it, since update_plotNew only skips dimensions whose value is None.
dim_inputs = [widgets.FloatText(value=None, description=f'Dim {i}:', continuous_update=False) for i in range(63)]

# Create the interactive plot
# Input boxes are passed as keywords 'dim_<i>', the naming update_plotNew parses.
interactive_plot = widgets.interactive(update_plotNew,
                                       frame_index=frame_slider,
                                       elev=elev_slider,
                                       azim=azim_slider,
                                       **{f'dim_{i}': input_box for i, input_box in enumerate(dim_inputs)})

# Display the interactive plot
# The 63 input boxes are laid out in three columns of 21.
display(widgets.VBox([
    interactive_plot,
    widgets.HBox([widgets.VBox(dim_inputs[:21]), widgets.VBox(dim_inputs[21:42]), widgets.VBox(dim_inputs[42:])])
]))