In [1]:
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

In [2]:
import json
import yaml
import glob
import tyro
from time import sleep
import imageio
import subprocess
from tqdm import tqdm
from typing import List, Tuple
from rich.progress import track
from threading import Thread
from queue import Queue
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
from scipy.signal import savgol_filter
import cv2
import torch
import torchaudio
import torchvision
import torchvision.transforms as transforms
import torch.multiprocessing as mp
import torch.distributed as dist
import torch.nn.functional as F
from torch.profiler import profile, record_function, ProfilerActivity
from transformers import Wav2Vec2Model, Wav2Vec2Processor
torchaudio.set_audio_backend("soundfile")

  from .autonotebook import tqdm as notebook_tqdm
  torchaudio.set_audio_backend("soundfile")


In [3]:
# Live Portrait
from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig
from src.config.crop_config import CropConfig
from src.utils.helper import load_model, concat_feat
from src.utils.camera import headpose_pred_to_degree, get_rotation_matrix
from src.utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
from src.utils.cropper import Cropper
from src.utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream
from src.utils.crop import _transform_img, prepare_paste_back, paste_back
from src.utils.io import load_image_rgb, load_video, resize_to_limit, dump, load
from src.utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix, is_image
from src.utils.filter import smooth

# DiT
from audio_dit.inference import InferenceManager, get_model
from audio_dit.dataset import load_and_process_pair
from audio_dit.dataset import process_motion_tensor

In [4]:
# wave2vec
MODEL_NAME = "facebook/wav2vec2-base-960h"
TARGET_SAMPLE_RATE = 16000
FRAME_RATE = 25
SECTION_LENGTH = 3
OVERLAP = 10

DB_ROOT = 'vox2-audio-tx'
LOG = 'log'
AUDIO = 'audio/audio'
OUTPUT_DIR = 'audio_encoder_output'
BATCH_SIZE = 16

prev_context_len = 67
gen_len_per_window = 8

# DiT model
config_path = 'D:/Projects/Upenn_CIS_5650/final-project/config/config.json'
weight_path = 'D:/Projects/Upenn_CIS_5650/final-project/config/model.pth'
motion_mean_exp_path = 'D:/Projects/Upenn_CIS_5650/final-project/config/000.npy'

cfg_s = 0.65
mouth_ratio = 0.25
subtract_avg_motion = False

headpose_bound_list = [-21, 25, -30, 30, -23, 23, -0.3, 0.3, -0.3, 0.28]

# input
input_image_path = 'D:/Projects/Upenn_CIS_5650/final-project/data/img/test5.jpg'
input_audio_path = 'D:/Projects/Upenn_CIS_5650/final-project/data/audio/test3.wav'

# output
output_no_audio_path = 'D:/Projects/Upenn_CIS_5650/final-project/LivePortrait/inference/animations/no_audio.mp4'
output_video = 'D:/Projects/Upenn_CIS_5650/final-project/LivePortrait/inference/animations/5_3_full_10_65_with_audio.mp4'

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [6]:
# LivePortrait Pipeline

def partial_fields(target_class, kwargs):
    return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})

args = ArgumentConfig()
inference_cfg = partial_fields(InferenceConfig, args.__dict__)
crop_cfg = partial_fields(CropConfig, args.__dict__)
print("Compile complete")

'''
Main function for inference
'''

def get_kp_info(x: torch.Tensor, **kwargs) -> dict:
    """ get the implicit keypoint information
    x: Bx3xHxW, normalized to 0~1
    flag_refine_info: whether to trandform the pose to degrees and the dimention of the reshape
    return: A dict contains keys: 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'
    """
    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16,
                                 enabled=inference_cfg.flag_use_half_precision):
        kp_info = motion_extractor(x)

        if inference_cfg.flag_use_half_precision:
            # float the dict
            for k, v in kp_info.items():
                if isinstance(v, torch.Tensor):
                    kp_info[k] = v.float()

    flag_refine_info: bool = kwargs.get('flag_refine_info', True)
    if flag_refine_info:
        bs = kp_info['kp'].shape[0]
        kp_info['pitch'] = headpose_pred_to_degree(kp_info['pitch'])[:, None]  # Bx1
        kp_info['yaw'] = headpose_pred_to_degree(kp_info['yaw'])[:, None]  # Bx1
        kp_info['roll'] = headpose_pred_to_degree(kp_info['roll'])[:, None]  # Bx1
        kp_info['kp'] = kp_info['kp'].reshape(bs, -1, 3)  # BxNx3
        kp_info['exp'] = kp_info['exp'].reshape(bs, -1, 3)  # BxNx3

    return kp_info

def prepare_source(img: np.ndarray) -> torch.Tensor:
    """ construct the input as standard
    img: HxWx3, uint8, 256x256
    """
    h, w = img.shape[:2]
    x = img.copy()

    if x.ndim == 3:
        x = x[np.newaxis].astype(np.float32) / 255.  # HxWx3 -> 1xHxWx3, normalized to 0~1
    elif x.ndim == 4:
        x = x.astype(np.float32) / 255.  # BxHxWx3, normalized to 0~1
    else:
        raise ValueError(f'img ndim should be 3 or 4: {x.ndim}')
    x = np.clip(x, 0, 1)  # clip to 0~1
    x = torch.from_numpy(x).permute(0, 3, 1, 2)  # 1xHxWx3 -> 1x3xHxW
    x = x.to(device)
    return x

def warp_decode(feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
    """ get the image after the warping of the implicit keypoints
    feature_3d: Bx32x16x64x64, feature volume
    kp_source: BxNx3
    kp_driving: BxNx3
    """
    # The line 18 in Algorithm 1: D(W(f_s; x_s, x′_d,i)）
    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16,
                                 enabled=inference_cfg.flag_use_half_precision):
        # get decoder input
        ret_dct = warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
        # decode
        ret_dct['out'] = spade_generator(feature=ret_dct['out'])

    return ret_dct

def extract_feature_3d( x: torch.Tensor) -> torch.Tensor:
    """ get the appearance feature of the image by F
    x: Bx3xHxW, normalized to 0~1
    """
    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16,
                                 enabled=inference_cfg.flag_use_half_precision):
        feature_3d = appearance_feature_extractor(x)

    return feature_3d.float()

def transform_keypoint(kp_info: dict):
    """
    transform the implicit keypoints with the pose, shift, and expression deformation
    kp: BxNx3
    """
    kp = kp_info['kp']    # (bs, k, 3)
    pitch, yaw, roll = kp_info['pitch'], kp_info['yaw'], kp_info['roll']

    t, exp = kp_info['t'], kp_info['exp']
    scale = kp_info['scale']

    pitch = headpose_pred_to_degree(pitch)
    yaw = headpose_pred_to_degree(yaw)
    roll = headpose_pred_to_degree(roll)

    bs = kp.shape[0]
    if kp.ndim == 2:
        num_kp = kp.shape[1] // 3  # Bx(num_kpx3)
    else:
        num_kp = kp.shape[1]  # Bxnum_kpx3

    rot_mat = get_rotation_matrix(pitch, yaw, roll)    # (bs, 3, 3)

    # Eqn.2: s * (R * x_c,s + exp) + t
    kp_transformed = kp.view(bs, num_kp, 3) @ rot_mat + exp.view(bs, num_kp, 3)
    kp_transformed *= scale[..., None]  # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3)
    kp_transformed[:, :, 0:2] += t[:, None, 0:2]  # remove z, only apply tx ty

    return kp_transformed

'''
Main module for inference
'''

model_config = yaml.load(open(inference_cfg.models_config, 'r'), Loader=yaml.SafeLoader)
# init F
appearance_feature_extractor = load_model(inference_cfg.checkpoint_F, model_config, device, 'appearance_feature_extractor')
# init M
motion_extractor = load_model(inference_cfg.checkpoint_M, model_config, device, 'motion_extractor')
# init W
warping_module = load_model(inference_cfg.checkpoint_W, model_config, device, 'warping_module')
# init G
spade_generator = load_model(inference_cfg.checkpoint_G, model_config, device, 'spade_generator')
# init S and R
if inference_cfg.checkpoint_S is not None and os.path.exists(inference_cfg.checkpoint_S):
    stitching_retargeting_module = load_model(inference_cfg.checkpoint_S, model_config, device, 'stitching_retargeting_module')
else:
    stitching_retargeting_module = None

cropper = Cropper(crop_cfg=crop_cfg, device=device)

Compile complete


In [7]:
# wav2vec Pipeline

def load_and_process_audio(file_path):
    waveform, sample_rate = torchaudio.load(file_path)

    original_sample_rate = sample_rate

    if sample_rate != TARGET_SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, sample_rate, TARGET_SAMPLE_RATE)

    # Convert to mono if stereo
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    # print(file_path," waveform.shape ",waveform.shape)

    # Calculate section length and overlap in samples
    section_samples = SECTION_LENGTH * 16027
    overlap_samples = int(OVERLAP / FRAME_RATE * TARGET_SAMPLE_RATE)
    # print('section_samples',section_samples,'overlap_samples',overlap_samples)

    # pad 10 overlap at the beginning
    waveform = torch.nn.functional.pad(waveform, (overlap_samples, 0))
    # Pad if shorter than 3 seconds
    if waveform.shape[1] < section_samples:
        waveform = torch.nn.functional.pad(waveform, (0, section_samples - waveform.shape[1]))
        return [waveform.squeeze(0)], original_sample_rate

    # Split into sections with overlap
    sections = []
    start = 0

    # print('starting to segment', file_path)
    while start < waveform.shape[1]:
        end = start + section_samples
        if end >= waveform.shape[1]:
            tmp=waveform[:, start:min(end, waveform.shape[1])]
            tmp = torch.nn.functional.pad(tmp, (0, section_samples - tmp.shape[1]))
            sections.append(tmp.squeeze(0))
            # print(tmp.shape)
            break
        else:
            sections.append(waveform[:, start:min(end, waveform.shape[1])].squeeze(0))

        start = int(end - overlap_samples)


    return file_path, sections

def inference_process_wav_file(path):
    audio_path, segments = load_and_process_audio(path)
    # print(audio_path,segments)
    segments = np.array(segments)

    inputs = wav2vec_processor(segments, sampling_rate=TARGET_SAMPLE_RATE, return_tensors="pt", padding=True).input_values.to(device)

    with torch.no_grad():
        outputs = wav2vec_model(inputs)
        latent = outputs.last_hidden_state

        seq_length = latent.shape[1]
        new_seq_length = int(seq_length * (FRAME_RATE / 50))

        latent_features_interpolated = F.interpolate(latent.transpose(1,2),
                                                     size=new_seq_length,
                                                     mode='linear',
                                                     align_corners=True).transpose(1,2)
    return latent_features_interpolated

def autoregress_load_and_process_audio(file_path):
    first_segment_prev_length = 10
    first_segment_main_length = 65
    remaining_segment_prev_length = prev_context_len
    remaining_segment_main_length = gen_len_per_window

    # below is the same as load_and_process_audio
    waveform, og_sample_rate = torchaudio.load(file_path)

    if og_sample_rate != TARGET_SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, og_sample_rate, TARGET_SAMPLE_RATE)

    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # define sample count
    per_window_samples = SECTION_LENGTH * 16027
    first_prev_samples = int(first_segment_prev_length * 16027 / FRAME_RATE)
    remaining_overlap_samples = int(remaining_segment_prev_length / FRAME_RATE * TARGET_SAMPLE_RATE)
    # pad 10 overlap at the beginning
    total_frame = int(waveform.shape[1] / TARGET_SAMPLE_RATE * FRAME_RATE) + 1
    waveform = torch.nn.functional.pad(waveform, (first_prev_samples, 0))

    # split into windows with overlap
    windows = []
    start = 0

    total_sample_count = waveform.shape[1]
    while start < total_sample_count:
        end = start + per_window_samples
        if end >= total_sample_count: # need to pad since last exceeds total sample count
            tmp = waveform[:, start:min(end, total_sample_count)]
            tmp = torch.nn.functional.pad(tmp, (0, per_window_samples - tmp.shape[1]))
            windows.append(tmp.squeeze(0))
            break
        else:
            windows.append(waveform[:, start:min(end, total_sample_count)].squeeze(0))
        start = int(end - remaining_overlap_samples)

    return windows, total_frame

def autoregress_inference_process_wav_file(path):
    windows, total_frame = autoregress_load_and_process_audio(path)
    print(f"total frame {total_frame}")
    windows = np.array(windows)

    inputs = wav2vec_processor(windows, sampling_rate=TARGET_SAMPLE_RATE, return_tensors="pt", padding=True).input_values.to(device)
    with torch.no_grad():
        outputs = wav2vec_model(inputs)
        latent = outputs.last_hidden_state

        seq_length = latent.shape[1]
        new_seq_length = int(seq_length * (FRAME_RATE / 50))

        latent_features_interpolated = F.interpolate(latent.transpose(1,2),
                                                     size=new_seq_length,
                                                     mode='linear',
                                                     align_corners=True).transpose(1,2)
    return latent_features_interpolated, total_frame

# Move model and processor to global scope
wav2vec_model = Wav2Vec2Model.from_pretrained(MODEL_NAME).to(device)
wav2vec_processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
# DiT Pipeline

audio_model_config = json.load(open(config_path))
inference_manager = get_model(config_path, weight_path, device)

Model checkpoint loaded from D:/Projects/Upenn_CIS_5650/final-project/config/model.pth


In [9]:
# Headpose tuning

# blink noise
def fixed_blink_noise(motion_tensor, motion_gt):
    eye_indices = torch.tensor([5, 9], device=motion_tensor.device)
    rest_latent = torch.tensor([-0.005], device=motion_tensor.device)
    spikes = torch.tensor([0.0015, 0.0085, 0.011, 0.0075, 0.002], device=motion_tensor.device)
    freeze_index = [0, 1, 2, 4, 6, 8, 10]
    freeze_index = torch.tensor(freeze_index, device=motion_tensor.device)
    period = 12
    period_counter = 0
    reset_flag = False
    for i in range(motion_tensor.shape[0]):
        in_period_index = i % period
        if in_period_index < 5 and period_counter >= period:
            for eye_index in eye_indices:
                motion_tensor[i, eye_index] = spikes[in_period_index]
                motion_tensor[i, eye_index] += torch.randn_like(motion_tensor[i, eye_index]) * 0.002
            if in_period_index == 4:
                reset_flag = True
                period_counter = 0
        else:
            for eye_index in eye_indices:
                motion_tensor[i, eye_index] = rest_latent
                motion_tensor[i, eye_index] += torch.randn_like(motion_tensor[i, eye_index]) * 0.0002
            period_counter += 1
        if reset_flag:
            reset_flag = False
            period = torch.randint(15, 18, (1,)).item()
        for f in freeze_index:
            # motion_tensor[i, f] = motion_gt[i, f]|
            motion_tensor[i, f] = 0
    return motion_tensor

# normalize headpose
def normalize_pose(full_motion, headpose_bound):
    assert headpose_bound is not None and len(headpose_bound) % 2 == 0
    headpose_bound = torch.tensor(headpose_bound)
    headpose_bound = headpose_bound.reshape(headpose_bound.shape[0] // 2, 2)

    # Assuming full_motion is a tensor of shape (batch_size, sequence_length, num_features)
    # and the last 5 features are the ones to be normalized
    last_5_features = full_motion[:, :, -5:]

    # Normalize each of the last 5 features
    for i in range(5):
        lower_bound = headpose_bound[i][0]
        upper_bound = headpose_bound[i][1]

        # Clamp the values within the specified bounds
        clamped = torch.clamp(last_5_features[:, :, i], min=lower_bound, max=upper_bound)

        # Normalize to the range [-0.05, 0.05]
        normalized = (clamped - lower_bound) / (upper_bound - lower_bound) * 0.1 - 0.05

        # Update the last 5 features with the normalized values
        last_5_features[:, :, i] = normalized

    # Update the full_motion tensor with the normalized last 5 features
    full_motion[:, :, -5:] = last_5_features

    return full_motion

def reverse_normalize_pose(normalized_motion, headpose_bound):
    assert headpose_bound is not None and len(headpose_bound) % 2 == 0
    headpose_bound = torch.tensor(headpose_bound)
    headpose_bound = headpose_bound.reshape(headpose_bound.shape[0] // 2, 2)

    # Assuming normalized_motion is a tensor of shape (batch_size, sequence_length, num_features)
    # and the last 5 features are the ones to be reversed
    last_5_features = normalized_motion[:, :, -5:]

    # Reverse normalization for each of the last 5 features
    for i in range(5):
        lower_bound = headpose_bound[i][0]
        upper_bound = headpose_bound[i][1]

        # Reverse the normalization from [-0.05, 0.05] to the original range
        original = (last_5_features[:, :, i] + 0.05) / 0.1 * (upper_bound - lower_bound) + lower_bound

        # Update the last 5 features with the original values
        last_5_features[:, :, i] = original

    # Update the normalized_motion tensor with the original last 5 features
    normalized_motion[:, :, -5:] = last_5_features

    return normalized_motion

In [10]:
# Motion tuning

motion_mean_exp_tensor = torch.from_numpy(np.load(motion_mean_exp_path)).to(device = device)
motion_mean_exp_tensor = motion_mean_exp_tensor.unsqueeze(0).to(device = device)
motion_tensor, _, _, _ = process_motion_tensor(motion_mean_exp_tensor, None, \
                            latent_type=audio_model_config['motion_latent_type'],
                            latent_mask_1=audio_model_config['latent_mask_1'],
                            latent_bound=torch.tensor(audio_model_config['latent_bound'], device=device),
                            use_headpose=True, headpose_bound=torch.tensor(headpose_bound_list, device=device))
mean_exp = torch.mean(motion_tensor.reshape(-1, motion_tensor.shape[-1]), dim=0)
print(motion_tensor.shape)
print(mean_exp.shape)
mouth_open_ratio_input = torch.tensor([mouth_ratio], device=device).unsqueeze(0)
motion_dim = audio_model_config['x_dim']

torch.Size([1, 1175, 25])
torch.Size([25])


  headpose_bound = torch.tensor(headpose_bound)
  latent_bound = torch.tensor(latent_bound)


In [11]:
# Process Input Image
ts = datetime.now()
img_rgb = load_image_rgb(input_image_path)

img_crop_256x256 = cv2.resize(img_rgb, (256, 256))  # force to resize to 256x256
I_s = prepare_source(img_crop_256x256)
x_s_info = get_kp_info(I_s)
x_c_s = x_s_info['kp']
x_s = transform_keypoint(x_s_info)
f_s = extract_feature_3d(I_s)

shape_in = x_c_s.reshape(1, -1).to(device)
te = datetime.now()
print(f"image processing time {(te - ts).total_seconds()}s")

image processing time 0.126749s


In [12]:
# Process Input Audio
ts = datetime.now()
audio_latent, total_frame = autoregress_inference_process_wav_file(input_audio_path)
print(f"audio latent shape {audio_latent.shape}")
te = datetime.now()
print(f"audio processing time {(te - ts).total_seconds()}s")

total frame 95
audio latent shape torch.Size([5, 75, 768])
audio processing time 0.095371s


  attn_output = torch.nn.functional.scaled_dot_product_attention(


In [13]:
def process_audio_stream(audio_latent, f_s, x_s, x_c_s, x_s_info, audio_model_config, warp_decode_func, output_buffer):
    ts = datetime.now()
    window_count = audio_latent.shape[0]
    out_motion = torch.tensor([], device=device)
    start_idx = 0

    for batch_index in range(0, window_count):

        start_time = datetime.now()

        if batch_index == 0:
            this_audio_prev = audio_latent[0:1, 0:10, :]
            audio_seq = audio_latent[0:1, 10:, :]
            this_motion_prev = torch.zeros(1, 10, motion_dim , device=device)
            gen_length = 65
        else:
            this_audio_prev = audio_latent[batch_index:batch_index+1, 0:prev_context_len, :]
            audio_seq = audio_latent[batch_index:batch_index+1, prev_context_len:, :]
            gen_length = gen_len_per_window

        mean_exp_expanded = mean_exp.expand(1, -1)

        tss = datetime.now()
        generated_motion, null_motion = inference_manager.inference(audio_seq,
                                                    shape_in, this_motion_prev, this_audio_prev, #seq_mask=seq_mask,
                                                    cfg_scale=cfg_s,
                                                    mouth_open_ratio = mouth_open_ratio_input,
                                                    denoising_steps=10,
                                                    gen_length=gen_length,
                                                    mean_exp=mean_exp_expanded)
        tee = datetime.now()

        print(f"{batch_index} inference time {(tee - tss).total_seconds()}s")
        full_window_motion = torch.cat((this_motion_prev, generated_motion), dim=1)
        this_motion_prev = full_window_motion[:, -prev_context_len:, :]

        if subtract_avg_motion:
            generated_motion = generated_motion - torch.mean(generated_motion, dim=-1, keepdim=True)

        generated_motion = (generated_motion - torch.mean(null_motion, dim=-2, keepdim=True)).squeeze(0)
        #print(f"{batch_index} generated motion {generated_motion.shape}")
        out_motion = torch.cat((out_motion, generated_motion), dim=0)
        #print(f"{batch_index} out motion {out_motion.shape}")

        end_time = datetime.now()
        print(f"{batch_index} window time {(end_time - start_time).total_seconds()}s")

    te = datetime.now()
    print(f"motion generation time {(te - ts).total_seconds()}s")

    filtered_start_time = datetime.now()

    # Then modify the filtering line to:
    out_motion_filtered = savgol_filter(out_motion[:total_frame].cpu().numpy(), window_length=5, polyorder=2, axis=0)
    out_motion_f = torch.tensor(out_motion_filtered, device=device)

    motion_gt = motion_tensor[:, :out_motion_f.shape[0], :].squeeze(0)
    out_motion_f = fixed_blink_noise(out_motion_f, motion_gt)
    out_pose_motion = out_motion_f[:, -5:]
    out_pose_smoothed = savgol_filter(out_pose_motion.cpu().numpy(), window_length=30, polyorder=2, axis=0)
    out_motion_f[:, -5:-2] *= 2
    out_motion_f[:, -2:] *= 0.5
    out_motion_f[:, -5:] = torch.tensor(out_pose_smoothed, device=device)

    full_motion_stacked = out_motion_f
    full_motion_stacked = reverse_normalize_pose(full_motion_stacked.unsqueeze(0), headpose_bound=torch.tensor(headpose_bound_list, device=device))
    full_motion = full_motion_stacked.squeeze(0)[start_idx:]
    start_idx = full_motion_stacked.shape[1]
    print(f"full motion stacked {full_motion_stacked.shape}")
    print(f"full motion {full_motion.shape}")
    print(f"next start_idx {start_idx}")

    filtered_end_time = datetime.now()
    print(f"filtering time {(filtered_end_time - filtered_start_time).total_seconds()}")

    time1 = datetime.now()
    pose = full_motion[:, -5:]
    exp = full_motion[:, :-5]

    full_63_exp = torch.zeros(full_motion.shape[0], 63, device=device)

    full_63_exp[:, audio_model_config['latent_mask_1']] = exp
    full_motion = full_63_exp.reshape(-1, 63)

    x_d_list = []
    scale = x_s_info['scale']

    for i in tqdm(range(full_motion.shape[0]), desc="Generating x_d"):
        exp = full_motion[i].reshape(21, 3)
        pitch, yaw, roll, t_x, t_y = pose[i].unsqueeze(0).unbind(-1)
        t = torch.tensor([t_x, t_y, 0], device=device)

        x_d_i = scale * (x_c_s @ get_rotation_matrix(pitch, yaw, roll) + exp) + t
        x_d_list.append(x_d_i.squeeze(0))

    x_d_batch = torch.stack(x_d_list, dim=0)
    time2 = datetime.now()

    for i in tqdm(range(full_motion.shape[0]), desc="Processing batches"):
        time3 = datetime.now()
        out = warp_decode_func(f_s, x_s, x_d_batch[i])

        # Convert to numpy array
        batch_frame = out['out'].permute(0, 2, 3, 1).mul_(255).to(torch.uint8).squeeze(0)
        time4 = datetime.now()
        print(f"{i} {(time4 - time3).total_seconds()}s")
        output_buffer.append(batch_frame)

    time5 = datetime.now()
    print(f"{(time2 - time1).total_seconds()}s {(time5 - time2).total_seconds()}s")
    print(f"cur buffer length {len(output_buffer)}")

    #output_buffer.put(None)

def display_frames(output_buffer, pre_time):
    while True:
        # Retrieve the next frame from the buffer
        frame_tensor = output_buffer.get()

        # End of processing
        if frame_tensor is None:
            break

        # Transfer the frame from pinned memory back to CPU
        frame = frame_tensor.cpu().numpy()
        result_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

        # Ensure 25 fps
        if (datetime.now() - pre_time).total_seconds() < 0.024:
            sleep(0.024 - (datetime.now() - pre_time).total_seconds())

        # Display the frame using OpenCV
        print(f"time to display {datetime.now() - pre_time}")
        cv2.imshow("Video Stream", result_bgr)
        cv2.waitKey(1)
        pre_time = datetime.now()

def stream_frames():
    # Record start time
    start_time = datetime.now()

    # Output buffer tracking
    output_buffer = Queue(maxsize = 5000)

    # Start frame processing thread
    processing_thread = Thread(target = process_audio_stream, args=(audio_latent, f_s, x_s, x_c_s, x_s_info, audio_model_config, warp_decode, output_buffer))
    processing_thread.start()

    # Start display frame
    display_frames(output_buffer, start_time)

    # Wait until terminate
    processing_thread.join()

    # Cleanup
    cv2.destroyAllWindows()


In [14]:
all_frames = []
process_audio_stream(audio_latent, f_s, x_s, x_c_s, x_s_info, audio_model_config, warp_decode, all_frames)

0 inference time 0.088378s
0 window time 0.089378s
1 inference time 0.061624s
1 window time 0.061624s
2 inference time 0.061601s
2 window time 0.061601s
3 inference time 0.063757s
3 window time 0.063757s


  headpose_bound = torch.tensor(headpose_bound)


4 inference time 0.063608s
4 window time 0.063608s
motion generation time 0.339968s
full motion stacked torch.Size([1, 95, 25])
full motion torch.Size([95, 25])
next start_idx 95
filtering time 0.044666


Generating x_d: 100%|██████████| 95/95 [00:00<00:00, 1909.00it/s]
Processing batches:  11%|█         | 10/95 [00:00<00:02, 30.90it/s]

0 0.156918s
1 0.019035s
2 0.024267s
3 0.023607s
4 0.025154s
5 0.022621s
6 0.023965s
7 0.024035s
8 0.026109s
9 0.026129s


Processing batches:  16%|█▌        | 15/95 [00:00<00:02, 36.09it/s]

10 0.023707s
11 0.02069s
12 0.023403s
13 0.021454s
14 0.021871s
15 0.024161s
16 0.021328s
17 0.021175s
18 0.021648s


Processing batches:  26%|██▋       | 25/95 [00:00<00:01, 39.88it/s]

19 0.037255s
20 0.020209s
21 0.02318s
22 0.021382s
23 0.022178s
24 0.02364s
25 0.021542s
26 0.023077s
27 0.021528s
28 0.021422s


Processing batches:  37%|███▋      | 35/95 [00:00<00:01, 42.42it/s]

29 0.02309s
30 0.023085s
31 0.023015s
32 0.023335s
33 0.018695s
34 0.022673s
35 0.023123s
36 0.022312s
37 0.02366s
38 0.022661s


Processing batches:  47%|████▋     | 45/95 [00:01<00:01, 43.46it/s]

39 0.021407s
40 0.02271s
41 0.022148s
42 0.022637s
43 0.021082s
44 0.021116s
45 0.019082s
46 0.022186s
47 0.023324s
48 0.020665s


Processing batches:  58%|█████▊    | 55/95 [00:01<00:00, 42.64it/s]

49 0.024591s
50 0.02506s
51 0.019679s
52 0.022247s
53 0.035541s
54 0.02265s
55 0.021187s
56 0.020695s
57 0.023988s


Processing batches:  68%|██████▊   | 65/95 [00:01<00:00, 43.71it/s]

58 0.022912s
59 0.021336s
60 0.023109s
61 0.020684s
62 0.021172s
63 0.02368s
64 0.023121s
65 0.019965s
66 0.022621s
67 0.021046s


Processing batches:  79%|███████▉  | 75/95 [00:01<00:00, 44.51it/s]

68 0.023271s
69 0.02339s
70 0.022313s
71 0.022408s
72 0.023099s
73 0.018453s
74 0.022646s
75 0.02226s
76 0.023053s
77 0.021881s


Processing batches:  89%|████████▉ | 85/95 [00:02<00:00, 44.53it/s]

78 0.022199s
79 0.022043s
80 0.023209s
81 0.022602s
82 0.023989s
83 0.020065s
84 0.022526s
85 0.020772s
86 0.029686s


Processing batches: 100%|██████████| 95/95 [00:02<00:00, 41.05it/s]

87 0.030107s
88 0.021682s
89 0.022131s
90 0.021045s
91 0.025122s
92 0.020011s
93 0.024974s
94 0.018017s
0.052766s 2.315736s
cur buffer length 95





In [15]:
def write_video(all_frames_in, audio_path, output_path):
    all_frames = []
    for frame in all_frames_in:
        all_frames.append(frame.cpu().numpy())
    output_no_audio_path = 'D:/Projects/Upenn_CIS_5650/final-project/LivePortrait/inference/animations/test_no_audio.mp4'
    if os.path.exists(output_no_audio_path):
        os.remove(output_no_audio_path)
    if os.path.exists(output_path):
        os.remove(output_path)
    fps = 25  # Adjust as needed

    height, width, layers = all_frames[0].shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video = cv2.VideoWriter(output_no_audio_path, fourcc, fps, (width, height))

    for frame in all_frames:
        video.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

    video.release()

    # Add audio to the video using ffmpeg
    input_video = output_no_audio_path
    input_audio = audio_path  # Use the path to your audio file

    ffmpeg_cmd = [
        'ffmpeg',
        '-i', input_video,
        '-i', input_audio,
        '-c:v', 'copy',
        '-c:a', 'aac',
        '-shortest',
        output_path
    ]

    try:
        subprocess.run(ffmpeg_cmd, check=True)
        os.remove(output_no_audio_path)
        print(f"Video with audio saved to {output_path}")
    except subprocess.CalledProcessError as e:
        print(f"Error adding audio to video: {e}")

In [16]:
write_video(all_frames, input_audio_path, output_video)

Video with audio saved to D:/Projects/Upenn_CIS_5650/final-project/LivePortrait/inference/animations/5_3_full_10_65_with_audio.mp4
