In [1]:
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

In [2]:
import json
import yaml
import glob
import tyro
import imageio
import subprocess
from tqdm import tqdm
from typing import List, Tuple
from rich.progress import track
from threading import Thread
from queue import Queue
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
import cv2
import torch
import torchaudio
import torchvision
import torchvision.transforms as transforms
import torch.multiprocessing as mp
import torch.distributed as dist
import torch.nn.functional as F
from transformers import Wav2Vec2Model, Wav2Vec2Processor

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Live Portrait
from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig
from src.config.crop_config import CropConfig
from src.utils.helper import load_model, concat_feat
from src.utils.camera import headpose_pred_to_degree, get_rotation_matrix
from src.utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
from src.utils.cropper import Cropper
from src.utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream
from src.utils.crop import _transform_img, prepare_paste_back, paste_back
from src.utils.io import load_image_rgb, load_video, resize_to_limit, dump, load
from src.utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix, is_image
from src.utils.filter import smooth

# DiT
from audio_dit.inference import InferenceManager, get_model
from audio_dit.dataset import load_and_process_pair

### Config

In [4]:
# Audio model
MODEL_NAME = "facebook/wav2vec2-base-960h"
TARGET_SAMPLE_RATE = 16000
FRAME_RATE = 25
SECTION_LENGTH = 3
OVERLAP = 10

DB_ROOT = 'vox2-audio-tx'
LOG = 'log'
AUDIO = 'audio/audio'
OUTPUT_DIR = 'audio_encoder_output'
BATCH_SIZE = 16

# DiT model
config_path = 'D:/Projects/Upenn_CIS_5650/final-project/config/config.json'
weight_path = 'D:/Projects/Upenn_CIS_5650/final-project/config/model.pth'

# input
input_image_path = 'D:/Projects/Upenn_CIS_5650/final-project/data/img/test5.jpg'
input_audio_path = 'D:/Projects/Upenn_CIS_5650/final-project/data/audio/test2.wav'

# output
output_no_audio_path = 'D:/Projects/Upenn_CIS_5650/final-project/LivePortrait/inference/animations/test5_no_audio.mp4'
output_video = 'D:/Projects/Upenn_CIS_5650/final-project/LivePortrait/inference/animations/test5_with_audio.mp4'

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

### Live Portrait Pipeline

In [6]:
def partial_fields(target_class, kwargs):
    return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})

args = ArgumentConfig()
inference_cfg = partial_fields(InferenceConfig, args.__dict__)
crop_cfg = partial_fields(CropConfig, args.__dict__)
print("Compile complete")

'''
Util functions
'''

def calculate_distance_ratio(lmk: np.ndarray, idx1: int, idx2: int, idx3: int, idx4: int, eps: float = 1e-6) -> np.ndarray:
    return (np.linalg.norm(lmk[:, idx1] - lmk[:, idx2], axis=1, keepdims=True) /
            (np.linalg.norm(lmk[:, idx3] - lmk[:, idx4], axis=1, keepdims=True) + eps))


def calc_eye_close_ratio(lmk: np.ndarray, target_eye_ratio: np.ndarray = None) -> np.ndarray:
    lefteye_close_ratio = calculate_distance_ratio(lmk, 6, 18, 0, 12)
    righteye_close_ratio = calculate_distance_ratio(lmk, 30, 42, 24, 36)
    if target_eye_ratio is not None:
        return np.concatenate([lefteye_close_ratio, righteye_close_ratio, target_eye_ratio], axis=1)
    else:
        return np.concatenate([lefteye_close_ratio, righteye_close_ratio], axis=1)


def calc_lip_close_ratio(lmk: np.ndarray) -> np.ndarray:
    return calculate_distance_ratio(lmk, 90, 102, 48, 66)

def calc_ratio(lmk_lst):
    input_eye_ratio_lst = []
    input_lip_ratio_lst = []
    for lmk in lmk_lst:
        # for eyes retargeting
        input_eye_ratio_lst.append(calc_eye_close_ratio(lmk[None]))
        # for lip retargeting
        input_lip_ratio_lst.append(calc_lip_close_ratio(lmk[None]))
    return input_eye_ratio_lst, input_lip_ratio_lst

def prepare_videos_(imgs, device):
    """ construct the input as standard
    imgs: NxHxWx3, uint8
    """
    if isinstance(imgs, list):
        _imgs = np.array(imgs)
    elif isinstance(imgs, np.ndarray):
        _imgs = imgs
    else:
        raise ValueError(f'imgs type error: {type(imgs)}')

    # y = _imgs.astype(np.float32) / 255.
    y = _imgs
    y = torch.from_numpy(y).permute(0, 3, 1, 2)  # NxHxWx3 -> Nx3xHxW
    y = y.to(device)
    y = y / 255.
    y = torch.clamp(y, 0, 1)

    return y

def get_kp_info(x: torch.Tensor, **kwargs) -> dict:
    """ get the implicit keypoint information
    x: Bx3xHxW, normalized to 0~1
    flag_refine_info: whether to trandform the pose to degrees and the dimention of the reshape
    return: A dict contains keys: 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'
    """
    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16,
                                 enabled=inference_cfg.flag_use_half_precision):
        kp_info = motion_extractor(x)

        if inference_cfg.flag_use_half_precision:
            # float the dict
            for k, v in kp_info.items():
                if isinstance(v, torch.Tensor):
                    kp_info[k] = v.float()

    flag_refine_info: bool = kwargs.get('flag_refine_info', True)
    if flag_refine_info:
        bs = kp_info['kp'].shape[0]
        kp_info['pitch'] = headpose_pred_to_degree(kp_info['pitch'])[:, None]  # Bx1
        kp_info['yaw'] = headpose_pred_to_degree(kp_info['yaw'])[:, None]  # Bx1
        kp_info['roll'] = headpose_pred_to_degree(kp_info['roll'])[:, None]  # Bx1
        kp_info['kp'] = kp_info['kp'].reshape(bs, -1, 3)  # BxNx3
        kp_info['exp'] = kp_info['exp'].reshape(bs, -1, 3)  # BxNx3

    return kp_info

def read_video_frames(video_path):
    cap = cv2.VideoCapture(video_path)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    frames = []
    for _ in range(frame_count):
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = cv2.resize(frame, (256, 256))  # Resize to 256x256
        frames.append(frame)

    cap.release()
    return video_path, frames

def read_multiple_videos(video_paths, num_threads=4):
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        results = list(executor.map(read_video_frames, video_paths))
    return results

def euler_to_quaternion(pitch, yaw, roll):
    cy = torch.cos(yaw * 0.5)
    sy = torch.sin(yaw * 0.5)
    cp = torch.cos(pitch * 0.5)
    sp = torch.sin(pitch * 0.5)
    cr = torch.cos(roll * 0.5)
    sr = torch.sin(roll * 0.5)

    w = cr * cp * cy + sr * sp * sy
    x = sr * cp * cy - cr * sp * sy
    y = cr * sp * cy + sr * cp * sy
    z = cr * cp * sy - sr * sp * cy

    return torch.stack([w, x, y, z], dim=-1)

def quaternion_to_euler(q):
    """
    Convert quaternion to Euler angles (pitch, yaw, roll) in radians.
    q: torch.Tensor of shape (..., 4) representing quaternions (w, x, y, z)
    Returns: tuple of (pitch, yaw, roll) as torch.Tensor
    """
    # Extract the values from q
    w, x, y, z = q[..., 0], q[..., 1], q[..., 2], q[..., 3]

    # Roll (x-axis rotation)
    sinr_cosp = 2 * (w * x + y * z)
    cosr_cosp = 1 - 2 * (x * x + y * y)
    roll = torch.atan2(sinr_cosp, cosr_cosp)

    # Pitch (y-axis rotation)
    sinp = 2 * (w * y - z * x)
    pitch = torch.where(
        torch.abs(sinp) >= 1,
        torch.sign(sinp) * torch.pi / 2,
        torch.asin(sinp)
    )

    # Yaw (z-axis rotation)
    siny_cosp = 2 * (w * z + x * y)
    cosy_cosp = 1 - 2 * (y * y + z * z)
    yaw = torch.atan2(siny_cosp, cosy_cosp)

    return pitch, yaw, roll

def quaternion_to_euler_degrees(q):
    """
    Convert quaternion to Euler angles (pitch, yaw, roll) in degrees.
    q: torch.Tensor of shape (..., 4) representing quaternions (w, x, y, z)
    Returns: tuple of (pitch, yaw, roll) as torch.Tensor in degrees
    """
    pitch, yaw, roll = quaternion_to_euler(q)
    return torch.rad2deg(pitch), torch.rad2deg(yaw), torch.rad2deg(roll)


'''
Loading source related modules
'''
def prepare_source(img: np.ndarray) -> torch.Tensor:
    """ construct the input as standard
    img: HxWx3, uint8, 256x256
    """
    h, w = img.shape[:2]
    x = img.copy()

    if x.ndim == 3:
        x = x[np.newaxis].astype(np.float32) / 255.  # HxWx3 -> 1xHxWx3, normalized to 0~1
    elif x.ndim == 4:
        x = x.astype(np.float32) / 255.  # BxHxWx3, normalized to 0~1
    else:
        raise ValueError(f'img ndim should be 3 or 4: {x.ndim}')
    x = np.clip(x, 0, 1)  # clip to 0~1
    x = torch.from_numpy(x).permute(0, 3, 1, 2)  # 1xHxWx3 -> 1x3xHxW
    x = x.to(device)
    return x

def warp_decode(feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
    """ get the image after the warping of the implicit keypoints
    feature_3d: Bx32x16x64x64, feature volume
    kp_source: BxNx3
    kp_driving: BxNx3
    """
    # The line 18 in Algorithm 1: D(W(f_s; x_s, x′_d,i)）
    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16,
                                 enabled=inference_cfg.flag_use_half_precision):
        # get decoder input
        ret_dct = warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
        # decode
        ret_dct['out'] = spade_generator(feature=ret_dct['out'])

    return ret_dct

def extract_feature_3d( x: torch.Tensor) -> torch.Tensor:
    """ get the appearance feature of the image by F
    x: Bx3xHxW, normalized to 0~1
    """
    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16,
                                 enabled=inference_cfg.flag_use_half_precision):
        feature_3d = appearance_feature_extractor(x)

    return feature_3d.float()

def transform_keypoint(kp_info: dict):
    """
    transform the implicit keypoints with the pose, shift, and expression deformation
    kp: BxNx3
    """
    kp = kp_info['kp']    # (bs, k, 3)
    pitch, yaw, roll = kp_info['pitch'], kp_info['yaw'], kp_info['roll']

    t, exp = kp_info['t'], kp_info['exp']
    scale = kp_info['scale']

    pitch = headpose_pred_to_degree(pitch)
    yaw = headpose_pred_to_degree(yaw)
    roll = headpose_pred_to_degree(roll)

    bs = kp.shape[0]
    if kp.ndim == 2:
        num_kp = kp.shape[1] // 3  # Bx(num_kpx3)
    else:
        num_kp = kp.shape[1]  # Bxnum_kpx3

    rot_mat = get_rotation_matrix(pitch, yaw, roll)    # (bs, 3, 3)

    # Eqn.2: s * (R * x_c,s + exp) + t
    kp_transformed = kp.view(bs, num_kp, 3) @ rot_mat + exp.view(bs, num_kp, 3)
    kp_transformed *= scale[..., None]  # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3)
    kp_transformed[:, :, 0:2] += t[:, None, 0:2]  # remove z, only apply tx ty

    return kp_transformed

def parse_output(out: torch.Tensor) -> np.ndarray:
    """ construct the output as standard
    return: 1xHxWx3, uint8
    """
    out = np.transpose(out.data.cpu().numpy(), [0, 2, 3, 1])  # 1x3xHxW -> 1xHxWx3
    out = np.clip(out, 0, 1)  # clip to 0~1
    out = np.clip(out * 255, 0, 255).astype(np.uint8)  # 0~1 -> 0~255

    return out
'''
Main module for inference
'''
model_config = yaml.load(open(inference_cfg.models_config, 'r'), Loader=yaml.SafeLoader)
# init F
appearance_feature_extractor = load_model(inference_cfg.checkpoint_F, model_config, device, 'appearance_feature_extractor')
# init M
motion_extractor = load_model(inference_cfg.checkpoint_M, model_config, device, 'motion_extractor')
# init W
warping_module = load_model(inference_cfg.checkpoint_W, model_config, device, 'warping_module')
# init G
spade_generator = load_model(inference_cfg.checkpoint_G, model_config, device, 'spade_generator')
# init S and R
if inference_cfg.checkpoint_S is not None and os.path.exists(inference_cfg.checkpoint_S):
    stitching_retargeting_module = load_model(inference_cfg.checkpoint_S, model_config, device, 'stitching_retargeting_module')
else:
    stitching_retargeting_module = None

cropper = Cropper(crop_cfg=crop_cfg, device=device)

Compile complete


### Audio Pipeline

In [7]:
# Move model and processor to global scope
wav2vec_model = Wav2Vec2Model.from_pretrained(MODEL_NAME).to(device)
wav2vec_processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)

def read_multiple_audios(paths, num_threads=12):
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        results = list(executor.map(load_and_process_audio, paths))
    return results


def read_json_and_form_paths(data,id_key):
    filenames=[]
    file_paths = []

    # Iterate through the nested structure
    for id_key, id_value in data.items():
        os.makedirs(os.path.join(DB_ROOT,OUTPUT_DIR,id_key), exist_ok=True)
        for url_key, url_value in id_value.items():
            for clip_id in url_value.keys():
                # Form the file path
                file_path = os.path.join(DB_ROOT,AUDIO,id_key, url_key, clip_id.replace('.txt', '.wav'))
                file_name = os.path.join(DB_ROOT,OUTPUT_DIR,id_key, url_key+'+'+clip_id.replace('.txt', ''))
                filenames.append(file_name)
                file_paths.append(file_path)

    return file_paths, filenames

def load_and_process_audio(file_path):
    waveform, sample_rate = torchaudio.load(file_path)

    original_sample_rate = sample_rate

    if sample_rate != TARGET_SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, sample_rate, TARGET_SAMPLE_RATE)

    # Convert to mono if stereo
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    print(file_path," waveform.shape ",waveform.shape)

    # Calculate section length and overlap in samples
    section_samples = SECTION_LENGTH * 16027
    overlap_samples = int(OVERLAP / FRAME_RATE * TARGET_SAMPLE_RATE)
    print('section_samples',section_samples,'overlap_samples',overlap_samples)

    # Pad if shorter than 3 seconds
    if waveform.shape[1] < section_samples:
        waveform = torch.nn.functional.pad(waveform, (0, section_samples - waveform.shape[1]))
        return [waveform.squeeze(0)], original_sample_rate

    # Split into sections with overlap
    sections = []
    start = 0

    print('starting to segment', file_path)
    while start < waveform.shape[1]:
        end = start + section_samples
        if end >= waveform.shape[1]:
            tmp=waveform[:, start:min(end, waveform.shape[1])]
            tmp = torch.nn.functional.pad(tmp, (0, section_samples - tmp.shape[1]))
            sections.append(tmp.squeeze(0))
            print(tmp.shape)
            break
        else:
            sections.append(waveform[:, start:min(end, waveform.shape[1])].squeeze(0))

        start = int(end - overlap_samples)


    return file_path, sections

def inference_process_wav_file(path):
    audio_path, segments = load_and_process_audio(path)
    print(audio_path,segments)
    segments = np.array(segments)

    inputs = wav2vec_processor(segments, sampling_rate=TARGET_SAMPLE_RATE, return_tensors="pt", padding=True).input_values.to(device)

    with torch.no_grad():
        outputs = wav2vec_model(inputs)
        latent = outputs.last_hidden_state

        seq_length = latent.shape[1]
        new_seq_length = int(seq_length * (FRAME_RATE / 50))

        latent_features_interpolated = F.interpolate(latent.transpose(1,2),
                                                     size=new_seq_length,
                                                     mode='linear',
                                                     align_corners=True).transpose(1,2)
    return latent_features_interpolated


def process_wav_file(paths, output_paths, uid):
    device = torch.device(f"cuda")

    model = Wav2Vec2Model.from_pretrained(MODEL_NAME).to(device)
    processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)

    read_2_gpu_batch_size = 2048
    gpu_batch_size = BATCH_SIZE
    process_queue = torch.Tensor().to(device)

    audio_segments = read_multiple_audios(paths, num_threads=4)
    all_segments = []
    total_segments = 0

    audio_lengths = []
    output_fns = []

    for (audio_path, segments), output_fn in zip(audio_segments,output_paths):
        all_segments.extend(segments)
        segment_count = len(segments)
        total_segments += segment_count
        audio_lengths.append(segment_count)

        output_fns.append(output_fn)

    all_segments = np.array(all_segments)
    print(all_segments.size)

    read_data_2_gpu_pointer = 0
    pbar = tqdm(total=total_segments, desc=f"Processing {uid}")

    while read_data_2_gpu_pointer < total_segments:
        current_batch_size = min(read_2_gpu_batch_size, total_segments - read_data_2_gpu_pointer)

        batch_input = all_segments[read_data_2_gpu_pointer:read_data_2_gpu_pointer + current_batch_size]

        mini_batch_start = 0
        all_info = []
        while mini_batch_start < batch_input.shape[0]:
            mini_batch_end = min(mini_batch_start + gpu_batch_size, batch_input.shape[0])
            mini_batch = batch_input[mini_batch_start:mini_batch_end]

            inputs = processor(mini_batch, sampling_rate=TARGET_SAMPLE_RATE, return_tensors="pt", padding=True).input_values.to(device)

            with torch.no_grad():
                outputs = model(inputs)

            latent = outputs.last_hidden_state
            print('latent',latent.shape)
            seq_length = latent.shape[1]
            new_seq_length = int(seq_length * (FRAME_RATE / 50))  # Assuming Wav2Vec2 outputs at ~50Hz

            latent_features_interpolated = F.interpolate(latent.transpose(1,2),
                                                            size=new_seq_length,
                                                            mode='linear',
                                                            align_corners=True).transpose(1,2)
            print('latent_features_interpolated',latent_features_interpolated.shape)
            all_info.append(latent_features_interpolated)

            mini_batch_start = mini_batch_end
        all_info_tensor = torch.cat(all_info, dim=0)

        process_queue = torch.cat((process_queue, all_info_tensor), dim=0)

        print(audio_lengths)
        while len(output_fns) > 0 and len(process_queue) >= audio_lengths[0]:
            current_output_fn = output_fns[0]
            current_segment_count = audio_lengths[0]

            audio_tensor = process_queue[:current_segment_count]
            np.save(current_output_fn, audio_tensor.cpu().numpy())
            print('save',current_output_fn)
            process_queue = process_queue[current_segment_count:]
            output_fns.pop(0)
            audio_lengths.pop(0)


        read_data_2_gpu_pointer += current_batch_size
        pbar.update(current_batch_size)

    pbar.close()

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### DiT Model

In [8]:
print("config_path exists:", os.path.exists(config_path))
audio_model_config = json.load(open(config_path))
inference_manager = get_model(config_path, weight_path, device)

config_path exists: True
Model checkpoint loaded from D:/Projects/Upenn_CIS_5650/final-project/config/model.pth


### Process Input Image

In [9]:
img_rgb = load_image_rgb(input_image_path)
source_rgb_lst = [img_rgb]

source_lmk = cropper.calc_lmk_from_cropped_image(source_rgb_lst[0])
img_crop_256x256 = cv2.resize(source_rgb_lst[0], (256, 256))  # force to resize to 256x256

I_s = prepare_source(img_crop_256x256)
x_s_info = get_kp_info(I_s)
x_c_s = x_s_info['kp']
x_s = transform_keypoint(x_s_info)
f_s = extract_feature_3d(I_s)

print(f"I_s shape {I_s.shape}")
print(f"x_c_s shape {x_c_s.shape}")
print(f"x_s shape {x_s.shape}")
print(f"f_s shape {f_s.shape}")

I_s shape torch.Size([1, 3, 256, 256])
x_c_s shape torch.Size([1, 21, 3])
x_s shape torch.Size([1, 21, 3])
f_s shape torch.Size([1, 32, 16, 64, 64])


### Process Input Audio

In [10]:
torchaudio.set_audio_backend("soundfile")
custom_audio_latent = inference_process_wav_file(input_audio_path)

custom_audio_latent.shape

D:/Projects/Upenn_CIS_5650/final-project/data/audio/test2.wav  waveform.shape  torch.Size([1, 662792])
section_samples 48081 overlap_samples 6400
starting to segment D:/Projects/Upenn_CIS_5650/final-project/data/audio/test2.wav
torch.Size([1, 48081])
D:/Projects/Upenn_CIS_5650/final-project/data/audio/test2.wav [tensor([0.0031, 0.0042, 0.0044,  ..., 0.0160, 0.0171, 0.0226]), tensor([0.0125, 0.0008, 0.0224,  ..., 0.0056, 0.0055, 0.0058]), tensor([0.0107, 0.0103, 0.0101,  ..., 0.0105, 0.0109, 0.0089]), tensor([-0.0445, -0.0396, -0.0345,  ..., -0.0307, -0.0460, -0.0543]), tensor([-0.0225, -0.0224, -0.0224,  ...,  0.0096,  0.0099,  0.0101]), tensor([0.0121, 0.0124, 0.0126,  ..., 0.0049, 0.0090, 0.0081]), tensor([-0.0153, -0.0097, -0.0081,  ...,  0.0958,  0.0839,  0.0701]), tensor([-0.0042, -0.0037, -0.0054,  ..., -0.0024, -0.0024, -0.0019]), tensor([-0.0078, -0.0073, -0.0069,  ..., -0.0206, -0.0178, -0.0157]), tensor([-0.0185, -0.0222, -0.0239,  ..., -0.0022, -0.0037, -0.0038]), tensor([ 0

  torchaudio.set_audio_backend("soundfile")
  attn_output = torch.nn.functional.scaled_dot_product_attention(


torch.Size([16, 75, 768])

### Generate Motion based on audio

In [11]:
used_audio_example = input_audio_path
audio_latent = custom_audio_latent

audio_latent_input = audio_latent
latent_mask_used=audio_model_config['latent_mask_1']
latent_bound = audio_model_config['latent_bound']
print(latent_mask_used, latent_bound)

audio_seq = audio_latent_input[:, 10:, :]
audio_prev = audio_latent_input[:, :10, :]

print("Audio input shape:", audio_seq.shape)
print("Audio previous shape:", audio_prev.shape)

motion_prev = torch.zeros(audio_latent.shape[0], 10, 6, device=device)

mouth_open_ratio_val = 0.25
mouth_open_ratio_input = torch.tensor([mouth_open_ratio_val], device=device).unsqueeze(0)
out_motion = torch.tensor([], device=device)
B, T, audio_dim = audio_seq.shape
motion_dim = audio_model_config['x_dim']
shape_in = x_c_s.reshape(1, -1).to(device)
this_audio_prev = torch.zeros(1, 10, audio_dim, device=device)
this_motion_prev = torch.zeros(1, 10, motion_dim , device=device)
motion_prev = torch.zeros(1, 10, motion_dim , device=device)
print("Audio input shape:", audio_seq.shape)
for batch_index in range(0, audio_seq.shape[0]):
    generated_motion = inference_manager.inference(audio_seq[batch_index:batch_index+1],
                                                shape_in, this_motion_prev, this_audio_prev,
                                                cfg_scale=0.25,
                                                mouth_open_ratio = mouth_open_ratio_input,
                                                denoising_steps=10)
    this_motion_prev = generated_motion[:, -10:, :]
    this_audio_prev = audio_seq[batch_index:batch_index+1, -10:, :]

    generated_motion = generated_motion - torch.mean(generated_motion, dim=-1, keepdim=True)
    out_motion = torch.cat((out_motion, generated_motion), dim=0)

generated_motion = out_motion
generated_motion.shape, motion_prev.shape

[4, 6, 7, 22, 33, 34, 40, 43, 45, 46, 48, 51, 52, 53, 57, 58, 59, 60, 61, 62] [-0.05029296875, 0.0857086181640625, -0.07587742805480957, 0.058624267578125, -0.0004341602325439453, 0.00019466876983642578, -0.038482666015625, 0.0345458984375, -0.030120849609375, 0.038360595703125, -3.0279159545898438e-05, 1.3887882232666016e-05, -0.0364990234375, 0.036102294921875, -0.043212890625, 0.046844482421875, -4.3332576751708984e-05, 1.8775463104248047e-05, -0.03326416015625, 0.057373046875, -0.03460693359375, 0.031707763671875, -0.0001958608627319336, 0.0005192756652832031, -0.0728759765625, 0.0587158203125, -0.04840087890625, 0.039642333984375, -0.00025916099548339844, 0.00048089027404785156, -0.09722900390625, 0.12469482421875, -0.1556396484375, 0.09326171875, -0.00018024444580078125, 0.00037860870361328125, -0.0279384758323431, 0.010650634765625, -0.039306640625, 0.03802490234375, -1.049041748046875e-05, 3.6954879760742188e-06, -0.032989501953125, 0.044281005859375, -0.037261962890625, 0.0433

(torch.Size([16, 65, 20]), torch.Size([1, 10, 20]))

### Stream Generated Frames Using Generated Motion

In [None]:
# Stream frame by frame

from time import sleep

def process_motion_stream(gen_motion_batch, motion_prev, f_s, x_s, warp_decode_func, output_buffer):
    # Process generated motion feature
    generated_motion = torch.cat([motion_prev[0], gen_motion_batch.reshape(-1, gen_motion_batch.shape[-1])], dim = 0)
    full_motion = torch.zeros(generated_motion.shape[0], 63, device = device)
    full_motion[:, audio_model_config['latent_mask_1']] = generated_motion

    # Process generated x_d (full motion feature)
    pitch = torch.zeros((1), dtype = torch.float32, device = device)
    yaw = torch.zeros((1), dtype = torch.float32, device = device)
    roll = torch.zeros((1), dtype = torch.float32, device = device)
    scale = torch.ones((1), dtype = torch.float32, device = device) * 1.5
    base_pose = x_c_s @ get_rotation_matrix(pitch, yaw, roll)
    x_d_batch = (scale * (base_pose + full_motion.reshape(-1, 21, 3))).squeeze(0)
    #output_buffer = torch.zeros((x_d_batch.shape[0], 512, 512, 3), dtype = torch.uint8, pin_memory = True)

    # Process frame
    for i in range(x_d_batch.shape[0]):
        # Step 1 process current frame through the warp_decode_func
        out = warp_decode_func(f_s, x_s, x_d_batch[i].unsqueeze(0))

        # Step 2 write to output pinned memory buffer
        output_buffer.put(out['out'].permute(0, 2, 3, 1).mul_(255).to(torch.uint8).squeeze(0))

    # Terminate
    output_buffer.put(None)

def display_frames(output_buffer, pre_time):
    while True:
        # Retrieve the next frame from the buffer
        frame_tensor = output_buffer.get()

        # End of processing
        if frame_tensor is None:
            break

        # Transfer the frame from pinned memory back to CPU
        frame = frame_tensor.cpu().numpy()
        result_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

        # Ensure 25 fps
        if (datetime.now() - pre_time).total_seconds() < 0.024:
            sleep(0.024 - (datetime.now() - pre_time).total_seconds())

        # Display the frame using OpenCV
        print(f"time to display {datetime.now() - pre_time}")
        cv2.imshow("Video Stream", result_bgr)
        cv2.waitKey(1)
        pre_time = datetime.now()

def stream_frames(generated_motion, motion_prev, f_s, x_s, warp_decode):
    # Record start time
    start_time = datetime.now()

    # Output buffer tracking
    output_buffer = Queue(maxsize = 5000)

    # Start frame processing thread
    processing_thread = Thread(target = process_motion_stream, args=(generated_motion, motion_prev, f_s, x_s, warp_decode, output_buffer))
    processing_thread.start()

    # Start display frame
    display_frames(output_buffer, start_time)

    # Wait until terminate
    processing_thread.join()

    # Cleanup
    cv2.destroyAllWindows()

In [13]:
stream_frames(generated_motion, motion_prev, f_s, x_s, warp_decode)

time to display 0:00:00.264128
time to display 0:00:00.024544
time to display 0:00:00.025197
time to display 0:00:00.024931
time to display 0:00:00.024653
time to display 0:00:00.024510
time to display 0:00:00.024522
time to display 0:00:00.025087
time to display 0:00:00.024516
time to display 0:00:00.024452
time to display 0:00:00.025012
time to display 0:00:00.024514
time to display 0:00:00.025086
time to display 0:00:00.024619
time to display 0:00:00.024551
time to display 0:00:00.025064
time to display 0:00:00.025029
time to display 0:00:00.024535
time to display 0:00:00.024537
time to display 0:00:00.025011
time to display 0:00:00.024514
time to display 0:00:00.025017
time to display 0:00:00.024537
time to display 0:00:00.025039
time to display 0:00:00.024506
time to display 0:00:00.024528
time to display 0:00:00.024548
time to display 0:00:00.024696
time to display 0:00:00.025049
time to display 0:00:00.024510
time to display 0:00:00.024595
time to display 0:00:00.024509
time to 

### Save Output as MP4

In [13]:
# Remove the files if they exist
if os.path.exists(output_no_audio_path):
    os.remove(output_no_audio_path)
if os.path.exists(output_video):
    os.remove(output_video)
fps = 25  # Adjust as needed

height, width, layers = all_frames[0].shape
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
video = cv2.VideoWriter(output_no_audio_path, fourcc, fps, (width, height))

for frame in all_frames:
    video.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

video.release()

# Add audio to the video using ffmpeg
input_video = output_no_audio_path
input_audio = used_audio_example  # Use the path to your audio file

ffmpeg_cmd = [
    'ffmpeg',
    '-i', input_video,
    '-i', input_audio,
    '-c:v', 'copy',
    '-c:a', 'aac',
    '-shortest',
    output_video
]

try:
    subprocess.run(ffmpeg_cmd, check=True)
    os.remove(output_no_audio_path)
    print(f"Video with audio saved to {output_video}")
except subprocess.CalledProcessError as e:
    print(f"Error adding audio to video: {e}")

Video with audio saved to D:/Projects/Upenn_CIS_5650/final-project/LivePortrait/inference/animations/test5_with_audio.mp4
