In [1]:
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

In [2]:
import json
import yaml
import glob
import tyro
import time
import imageio
import subprocess
from tqdm import tqdm
from typing import List, Tuple
from rich.progress import track
from threading import Thread
from queue import Queue
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
import cv2
import torch
import torchaudio
import torchvision
import torchvision.transforms as transforms
import torch.multiprocessing as mp
import torch.distributed as dist
import torch.nn.functional as F
from torch.profiler import profile, record_function, ProfilerActivity
from transformers import Wav2Vec2Model, Wav2Vec2Processor

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Live Portrait
from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig
from src.config.crop_config import CropConfig
from src.utils.helper import load_model, concat_feat
from src.utils.camera import headpose_pred_to_degree, get_rotation_matrix
from src.utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
from src.utils.cropper import Cropper
from src.utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream
from src.utils.crop import _transform_img, prepare_paste_back, paste_back
from src.utils.io import load_image_rgb, load_video, resize_to_limit, dump, load
from src.utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix, is_image
from src.utils.filter import smooth

# DiT
from audio_dit.inference import InferenceManager, get_model
from audio_dit.dataset import load_and_process_pair

In [4]:
# Audio model
MODEL_NAME = "facebook/wav2vec2-base-960h"
TARGET_SAMPLE_RATE = 16000
FRAME_RATE = 25
SECTION_LENGTH = 3
OVERLAP = 10

DB_ROOT = 'vox2-audio-tx'
LOG = 'log'
AUDIO = 'audio/audio'
OUTPUT_DIR = 'audio_encoder_output'
BATCH_SIZE = 16

# DiT model
config_path = 'D:/Projects/Upenn_CIS_5650/final-project/config/config.json'
weight_path = 'D:/Projects/Upenn_CIS_5650/final-project/config/model.pth'

# input
input_image_path = 'D:/Projects/Upenn_CIS_5650/final-project/data/img/test4.jpg'
input_audio_path = 'D:/Projects/Upenn_CIS_5650/final-project/data/audio/test3.wav'

# output
output_no_audio_path = 'D:/Projects/Upenn_CIS_5650/final-project/LivePortrait/inference/animations/test5_no_audio.mp4'
output_video = 'D:/Projects/Upenn_CIS_5650/final-project/LivePortrait/inference/animations/test5_with_audio.mp4'

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [6]:
# Live Portrait Pipeline

def partial_fields(target_class, kwargs):
    return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})

args = ArgumentConfig()
inference_cfg = partial_fields(InferenceConfig, args.__dict__)
crop_cfg = partial_fields(CropConfig, args.__dict__)
print("Compile complete")

'''
Util functions
'''

def calculate_distance_ratio(lmk: np.ndarray, idx1: int, idx2: int, idx3: int, idx4: int, eps: float = 1e-6) -> np.ndarray:
    return (np.linalg.norm(lmk[:, idx1] - lmk[:, idx2], axis=1, keepdims=True) /
            (np.linalg.norm(lmk[:, idx3] - lmk[:, idx4], axis=1, keepdims=True) + eps))


def calc_eye_close_ratio(lmk: np.ndarray, target_eye_ratio: np.ndarray = None) -> np.ndarray:
    lefteye_close_ratio = calculate_distance_ratio(lmk, 6, 18, 0, 12)
    righteye_close_ratio = calculate_distance_ratio(lmk, 30, 42, 24, 36)
    if target_eye_ratio is not None:
        return np.concatenate([lefteye_close_ratio, righteye_close_ratio, target_eye_ratio], axis=1)
    else:
        return np.concatenate([lefteye_close_ratio, righteye_close_ratio], axis=1)


def calc_lip_close_ratio(lmk: np.ndarray) -> np.ndarray:
    return calculate_distance_ratio(lmk, 90, 102, 48, 66)

def calc_ratio(lmk_lst):
    input_eye_ratio_lst = []
    input_lip_ratio_lst = []
    for lmk in lmk_lst:
        # for eyes retargeting
        input_eye_ratio_lst.append(calc_eye_close_ratio(lmk[None]))
        # for lip retargeting
        input_lip_ratio_lst.append(calc_lip_close_ratio(lmk[None]))
    return input_eye_ratio_lst, input_lip_ratio_lst

def prepare_videos_(imgs, device):
    """ construct the input as standard
    imgs: NxHxWx3, uint8
    """
    if isinstance(imgs, list):
        _imgs = np.array(imgs)
    elif isinstance(imgs, np.ndarray):
        _imgs = imgs
    else:
        raise ValueError(f'imgs type error: {type(imgs)}')

    # y = _imgs.astype(np.float32) / 255.
    y = _imgs
    y = torch.from_numpy(y).permute(0, 3, 1, 2)  # NxHxWx3 -> Nx3xHxW
    y = y.to(device)
    y = y / 255.
    y = torch.clamp(y, 0, 1)

    return y

def get_kp_info(x: torch.Tensor, **kwargs) -> dict:
    """ get the implicit keypoint information
    x: Bx3xHxW, normalized to 0~1
    flag_refine_info: whether to trandform the pose to degrees and the dimention of the reshape
    return: A dict contains keys: 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'
    """
    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16,
                                 enabled=inference_cfg.flag_use_half_precision):
        kp_info = motion_extractor(x)

        if inference_cfg.flag_use_half_precision:
            # float the dict
            for k, v in kp_info.items():
                if isinstance(v, torch.Tensor):
                    kp_info[k] = v.float()

    flag_refine_info: bool = kwargs.get('flag_refine_info', True)
    if flag_refine_info:
        bs = kp_info['kp'].shape[0]
        kp_info['pitch'] = headpose_pred_to_degree(kp_info['pitch'])[:, None]  # Bx1
        kp_info['yaw'] = headpose_pred_to_degree(kp_info['yaw'])[:, None]  # Bx1
        kp_info['roll'] = headpose_pred_to_degree(kp_info['roll'])[:, None]  # Bx1
        kp_info['kp'] = kp_info['kp'].reshape(bs, -1, 3)  # BxNx3
        kp_info['exp'] = kp_info['exp'].reshape(bs, -1, 3)  # BxNx3

    return kp_info

def read_video_frames(video_path):
    cap = cv2.VideoCapture(video_path)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    frames = []
    for _ in range(frame_count):
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = cv2.resize(frame, (256, 256))  # Resize to 256x256
        frames.append(frame)

    cap.release()
    return video_path, frames

def read_multiple_videos(video_paths, num_threads=4):
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        results = list(executor.map(read_video_frames, video_paths))
    return results

def euler_to_quaternion(pitch, yaw, roll):
    cy = torch.cos(yaw * 0.5)
    sy = torch.sin(yaw * 0.5)
    cp = torch.cos(pitch * 0.5)
    sp = torch.sin(pitch * 0.5)
    cr = torch.cos(roll * 0.5)
    sr = torch.sin(roll * 0.5)

    w = cr * cp * cy + sr * sp * sy
    x = sr * cp * cy - cr * sp * sy
    y = cr * sp * cy + sr * cp * sy
    z = cr * cp * sy - sr * sp * cy

    return torch.stack([w, x, y, z], dim=-1)

def quaternion_to_euler(q):
    """
    Convert quaternion to Euler angles (pitch, yaw, roll) in radians.
    q: torch.Tensor of shape (..., 4) representing quaternions (w, x, y, z)
    Returns: tuple of (pitch, yaw, roll) as torch.Tensor
    """
    # Extract the values from q
    w, x, y, z = q[..., 0], q[..., 1], q[..., 2], q[..., 3]

    # Roll (x-axis rotation)
    sinr_cosp = 2 * (w * x + y * z)
    cosr_cosp = 1 - 2 * (x * x + y * y)
    roll = torch.atan2(sinr_cosp, cosr_cosp)

    # Pitch (y-axis rotation)
    sinp = 2 * (w * y - z * x)
    pitch = torch.where(
        torch.abs(sinp) >= 1,
        torch.sign(sinp) * torch.pi / 2,
        torch.asin(sinp)
    )

    # Yaw (z-axis rotation)
    siny_cosp = 2 * (w * z + x * y)
    cosy_cosp = 1 - 2 * (y * y + z * z)
    yaw = torch.atan2(siny_cosp, cosy_cosp)

    return pitch, yaw, roll

def quaternion_to_euler_degrees(q):
    """
    Convert quaternion to Euler angles (pitch, yaw, roll) in degrees.
    q: torch.Tensor of shape (..., 4) representing quaternions (w, x, y, z)
    Returns: tuple of (pitch, yaw, roll) as torch.Tensor in degrees
    """
    pitch, yaw, roll = quaternion_to_euler(q)
    return torch.rad2deg(pitch), torch.rad2deg(yaw), torch.rad2deg(roll)


'''
Loading source related modules
'''
def prepare_source(img: np.ndarray) -> torch.Tensor:
    """ construct the input as standard
    img: HxWx3, uint8, 256x256
    """
    h, w = img.shape[:2]
    x = img.copy()

    if x.ndim == 3:
        x = x[np.newaxis].astype(np.float32) / 255.  # HxWx3 -> 1xHxWx3, normalized to 0~1
    elif x.ndim == 4:
        x = x.astype(np.float32) / 255.  # BxHxWx3, normalized to 0~1
    else:
        raise ValueError(f'img ndim should be 3 or 4: {x.ndim}')
    x = np.clip(x, 0, 1)  # clip to 0~1
    x = torch.from_numpy(x).permute(0, 3, 1, 2)  # 1xHxWx3 -> 1x3xHxW
    x = x.to(device)
    return x

def warp_decode(feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
    """ get the image after the warping of the implicit keypoints
    feature_3d: Bx32x16x64x64, feature volume
    kp_source: BxNx3
    kp_driving: BxNx3
    """
    # The line 18 in Algorithm 1: D(W(f_s; x_s, x′_d,i)）
    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16,
                                 enabled=inference_cfg.flag_use_half_precision):
        # get decoder input
        ret_dct = warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
        # decode
        ret_dct['out'] = spade_generator(feature=ret_dct['out'])

    return ret_dct

def extract_feature_3d( x: torch.Tensor) -> torch.Tensor:
    """ get the appearance feature of the image by F
    x: Bx3xHxW, normalized to 0~1
    """
    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16,
                                 enabled=inference_cfg.flag_use_half_precision):
        feature_3d = appearance_feature_extractor(x)

    return feature_3d.float()

def transform_keypoint(kp_info: dict):
    """
    transform the implicit keypoints with the pose, shift, and expression deformation
    kp: BxNx3
    """
    kp = kp_info['kp']    # (bs, k, 3)
    pitch, yaw, roll = kp_info['pitch'], kp_info['yaw'], kp_info['roll']

    t, exp = kp_info['t'], kp_info['exp']
    scale = kp_info['scale']

    pitch = headpose_pred_to_degree(pitch)
    yaw = headpose_pred_to_degree(yaw)
    roll = headpose_pred_to_degree(roll)

    bs = kp.shape[0]
    if kp.ndim == 2:
        num_kp = kp.shape[1] // 3  # Bx(num_kpx3)
    else:
        num_kp = kp.shape[1]  # Bxnum_kpx3

    rot_mat = get_rotation_matrix(pitch, yaw, roll)    # (bs, 3, 3)

    # Eqn.2: s * (R * x_c,s + exp) + t
    kp_transformed = kp.view(bs, num_kp, 3) @ rot_mat + exp.view(bs, num_kp, 3)
    kp_transformed *= scale[..., None]  # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3)
    kp_transformed[:, :, 0:2] += t[:, None, 0:2]  # remove z, only apply tx ty

    return kp_transformed

def parse_output(out: torch.Tensor) -> np.ndarray:
    """ construct the output as standard
    return: 1xHxWx3, uint8
    """
    out = np.transpose(out.data.cpu().numpy(), [0, 2, 3, 1])  # 1x3xHxW -> 1xHxWx3
    out = np.clip(out, 0, 1)  # clip to 0~1
    out = np.clip(out * 255, 0, 255).astype(np.uint8)  # 0~1 -> 0~255

    return out
'''
Main module for inference
'''
model_config = yaml.load(open(inference_cfg.models_config, 'r'), Loader=yaml.SafeLoader)
# init F
appearance_feature_extractor = load_model(inference_cfg.checkpoint_F, model_config, device, 'appearance_feature_extractor')
# init M
motion_extractor = load_model(inference_cfg.checkpoint_M, model_config, device, 'motion_extractor')
# init W
warping_module = load_model(inference_cfg.checkpoint_W, model_config, device, 'warping_module')
# init G
spade_generator = load_model(inference_cfg.checkpoint_G, model_config, device, 'spade_generator')
# init S and R
if inference_cfg.checkpoint_S is not None and os.path.exists(inference_cfg.checkpoint_S):
    stitching_retargeting_module = load_model(inference_cfg.checkpoint_S, model_config, device, 'stitching_retargeting_module')
else:
    stitching_retargeting_module = None

cropper = Cropper(crop_cfg=crop_cfg, device=device)

Compile complete


In [7]:
# device = torch.device('cuda')  # Replace 'cuda' with 'cpu' if checking for CPU
# model_on_device = all(param.device == device for param in model.parameters())
# print(f"Is the model on {device}? {model_on_device}")

for name, param in warping_module.named_parameters():
    print(f"Parameter {name} is on device: {param.device}")

Parameter dense_motion_network.hourglass.encoder.down_blocks.0.conv.weight is on device: cuda:0
Parameter dense_motion_network.hourglass.encoder.down_blocks.0.conv.bias is on device: cuda:0
Parameter dense_motion_network.hourglass.encoder.down_blocks.0.norm.weight is on device: cuda:0
Parameter dense_motion_network.hourglass.encoder.down_blocks.0.norm.bias is on device: cuda:0
Parameter dense_motion_network.hourglass.encoder.down_blocks.1.conv.weight is on device: cuda:0
Parameter dense_motion_network.hourglass.encoder.down_blocks.1.conv.bias is on device: cuda:0
Parameter dense_motion_network.hourglass.encoder.down_blocks.1.norm.weight is on device: cuda:0
Parameter dense_motion_network.hourglass.encoder.down_blocks.1.norm.bias is on device: cuda:0
Parameter dense_motion_network.hourglass.encoder.down_blocks.2.conv.weight is on device: cuda:0
Parameter dense_motion_network.hourglass.encoder.down_blocks.2.conv.bias is on device: cuda:0
Parameter dense_motion_network.hourglass.encoder.d

In [8]:
for name, param in spade_generator.named_parameters():
    print(f"Parameter {name} is on device: {param.device}")

Parameter fc.weight is on device: cuda:0
Parameter fc.bias is on device: cuda:0
Parameter G_middle_0.conv_0.bias is on device: cuda:0
Parameter G_middle_0.conv_0.weight_orig is on device: cuda:0
Parameter G_middle_0.conv_1.bias is on device: cuda:0
Parameter G_middle_0.conv_1.weight_orig is on device: cuda:0
Parameter G_middle_0.norm_0.mlp_shared.0.weight is on device: cuda:0
Parameter G_middle_0.norm_0.mlp_shared.0.bias is on device: cuda:0
Parameter G_middle_0.norm_0.mlp_gamma.weight is on device: cuda:0
Parameter G_middle_0.norm_0.mlp_gamma.bias is on device: cuda:0
Parameter G_middle_0.norm_0.mlp_beta.weight is on device: cuda:0
Parameter G_middle_0.norm_0.mlp_beta.bias is on device: cuda:0
Parameter G_middle_0.norm_1.mlp_shared.0.weight is on device: cuda:0
Parameter G_middle_0.norm_1.mlp_shared.0.bias is on device: cuda:0
Parameter G_middle_0.norm_1.mlp_gamma.weight is on device: cuda:0
Parameter G_middle_0.norm_1.mlp_gamma.bias is on device: cuda:0
Parameter G_middle_0.norm_1.ml

In [7]:
# Audio Pipeline

# Move model and processor to global scope
wav2vec_model = Wav2Vec2Model.from_pretrained(MODEL_NAME).to(device)
wav2vec_processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)

def read_multiple_audios(paths, num_threads=12):
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        results = list(executor.map(load_and_process_audio, paths))
    return results


def read_json_and_form_paths(data,id_key):
    filenames=[]
    file_paths = []

    # Iterate through the nested structure
    for id_key, id_value in data.items():
        os.makedirs(os.path.join(DB_ROOT,OUTPUT_DIR,id_key), exist_ok=True)
        for url_key, url_value in id_value.items():
            for clip_id in url_value.keys():
                # Form the file path
                file_path = os.path.join(DB_ROOT,AUDIO,id_key, url_key, clip_id.replace('.txt', '.wav'))
                file_name = os.path.join(DB_ROOT,OUTPUT_DIR,id_key, url_key+'+'+clip_id.replace('.txt', ''))
                filenames.append(file_name)
                file_paths.append(file_path)

    return file_paths, filenames

def load_and_process_audio(file_path):
    waveform, sample_rate = torchaudio.load(file_path)

    original_sample_rate = sample_rate

    if sample_rate != TARGET_SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, sample_rate, TARGET_SAMPLE_RATE)

    # Convert to mono if stereo
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    print(file_path," waveform.shape ",waveform.shape)

    # Calculate section length and overlap in samples
    section_samples = SECTION_LENGTH * 16027
    overlap_samples = int(OVERLAP / FRAME_RATE * TARGET_SAMPLE_RATE)
    print('section_samples',section_samples,'overlap_samples',overlap_samples)

    # Pad if shorter than 3 seconds
    if waveform.shape[1] < section_samples:
        waveform = torch.nn.functional.pad(waveform, (0, section_samples - waveform.shape[1]))
        return [waveform.squeeze(0)], original_sample_rate

    # Split into sections with overlap
    sections = []
    start = 0

    print('starting to segment', file_path)
    while start < waveform.shape[1]:
        end = start + section_samples
        if end >= waveform.shape[1]:
            tmp=waveform[:, start:min(end, waveform.shape[1])]
            tmp = torch.nn.functional.pad(tmp, (0, section_samples - tmp.shape[1]))
            sections.append(tmp.squeeze(0))
            print(tmp.shape)
            break
        else:
            sections.append(waveform[:, start:min(end, waveform.shape[1])].squeeze(0))

        start = int(end - overlap_samples)


    return file_path, sections

def inference_process_wav_file(path):
    audio_path, segments = load_and_process_audio(path)
    print(audio_path,segments)
    segments = np.array(segments)

    inputs = wav2vec_processor(segments, sampling_rate=TARGET_SAMPLE_RATE, return_tensors="pt", padding=True).input_values.to(device)

    with torch.no_grad():
        outputs = wav2vec_model(inputs)
        latent = outputs.last_hidden_state

        seq_length = latent.shape[1]
        new_seq_length = int(seq_length * (FRAME_RATE / 50))

        latent_features_interpolated = F.interpolate(latent.transpose(1,2),
                                                     size=new_seq_length,
                                                     mode='linear',
                                                     align_corners=True).transpose(1,2)
    return latent_features_interpolated


def process_wav_file(paths, output_paths, uid):
    device = torch.device(f"cuda")

    model = Wav2Vec2Model.from_pretrained(MODEL_NAME).to(device)
    processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)

    read_2_gpu_batch_size = 2048
    gpu_batch_size = BATCH_SIZE
    process_queue = torch.Tensor().to(device)

    audio_segments = read_multiple_audios(paths, num_threads=4)
    all_segments = []
    total_segments = 0

    audio_lengths = []
    output_fns = []

    for (audio_path, segments), output_fn in zip(audio_segments,output_paths):
        all_segments.extend(segments)
        segment_count = len(segments)
        total_segments += segment_count
        audio_lengths.append(segment_count)

        output_fns.append(output_fn)

    all_segments = np.array(all_segments)
    print(all_segments.size)

    read_data_2_gpu_pointer = 0
    pbar = tqdm(total=total_segments, desc=f"Processing {uid}")

    while read_data_2_gpu_pointer < total_segments:
        current_batch_size = min(read_2_gpu_batch_size, total_segments - read_data_2_gpu_pointer)

        batch_input = all_segments[read_data_2_gpu_pointer:read_data_2_gpu_pointer + current_batch_size]

        mini_batch_start = 0
        all_info = []
        while mini_batch_start < batch_input.shape[0]:
            mini_batch_end = min(mini_batch_start + gpu_batch_size, batch_input.shape[0])
            mini_batch = batch_input[mini_batch_start:mini_batch_end]

            inputs = processor(mini_batch, sampling_rate=TARGET_SAMPLE_RATE, return_tensors="pt", padding=True).input_values.to(device)

            with torch.no_grad():
                outputs = model(inputs)

            latent = outputs.last_hidden_state
            print('latent',latent.shape)
            seq_length = latent.shape[1]
            new_seq_length = int(seq_length * (FRAME_RATE / 50))  # Assuming Wav2Vec2 outputs at ~50Hz

            latent_features_interpolated = F.interpolate(latent.transpose(1,2),
                                                            size=new_seq_length,
                                                            mode='linear',
                                                            align_corners=True).transpose(1,2)
            print('latent_features_interpolated',latent_features_interpolated.shape)
            all_info.append(latent_features_interpolated)

            mini_batch_start = mini_batch_end
        all_info_tensor = torch.cat(all_info, dim=0)

        process_queue = torch.cat((process_queue, all_info_tensor), dim=0)

        print(audio_lengths)
        while len(output_fns) > 0 and len(process_queue) >= audio_lengths[0]:
            current_output_fn = output_fns[0]
            current_segment_count = audio_lengths[0]

            audio_tensor = process_queue[:current_segment_count]
            np.save(current_output_fn, audio_tensor.cpu().numpy())
            print('save',current_output_fn)
            process_queue = process_queue[current_segment_count:]
            output_fns.pop(0)
            audio_lengths.pop(0)


        read_data_2_gpu_pointer += current_batch_size
        pbar.update(current_batch_size)

    pbar.close()

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
# DiT Model

print("config_path exists:", os.path.exists(config_path))
audio_model_config = json.load(open(config_path))
inference_manager = get_model(config_path, weight_path, device)

config_path exists: True


RuntimeError: Error(s) in loading state_dict for DiffLiveHead:
	Unexpected key(s) in state_dict: "start_audio_feat", "start_motion_feat". 

In [11]:
# Process Input Image

img_rgb = load_image_rgb(input_image_path)
source_rgb_lst = [img_rgb]

source_lmk = cropper.calc_lmk_from_cropped_image(source_rgb_lst[0])
img_crop_256x256 = cv2.resize(source_rgb_lst[0], (256, 256))  # force to resize to 256x256

I_s = prepare_source(img_crop_256x256)
x_s_info = get_kp_info(I_s)
x_c_s = x_s_info['kp']
x_s = transform_keypoint(x_s_info)
f_s = extract_feature_3d(I_s)

print(f"I_s shape {I_s.shape}")
print(f"x_c_s shape {x_c_s.shape}")
print(f"x_s shape {x_s.shape}")
print(f"f_s shape {f_s.shape}")

I_s shape torch.Size([1, 3, 256, 256])
x_c_s shape torch.Size([1, 21, 3])
x_s shape torch.Size([1, 21, 3])
f_s shape torch.Size([1, 32, 16, 64, 64])


In [12]:
# Process Input Audio

torchaudio.set_audio_backend("soundfile")
custom_audio_latent = inference_process_wav_file(input_audio_path)

custom_audio_latent.shape

D:/Projects/Upenn_CIS_5650/final-project/data/audio/test3.wav  waveform.shape  torch.Size([1, 60338])
section_samples 48081 overlap_samples 6400
starting to segment D:/Projects/Upenn_CIS_5650/final-project/data/audio/test3.wav
torch.Size([1, 48081])
D:/Projects/Upenn_CIS_5650/final-project/data/audio/test3.wav [tensor([ 0.0000,  0.0000,  0.0000,  ..., -0.0016, -0.0016, -0.0016]), tensor([-0.0778, -0.0795, -0.0818,  ...,  0.0000,  0.0000,  0.0000])]


  torchaudio.set_audio_backend("soundfile")
  attn_output = torch.nn.functional.scaled_dot_product_attention(


torch.Size([2, 75, 768])

In [15]:
# Generate Motion

used_audio_example = input_audio_path
audio_latent = custom_audio_latent

audio_latent_input = audio_latent
latent_mask_used=audio_model_config['latent_mask_1']
latent_bound = audio_model_config['latent_bound']
print(latent_mask_used, latent_bound)

audio_seq = audio_latent_input[:, 10:, :]
audio_prev = audio_latent_input[:, :10, :]

print("Audio input shape:", audio_seq.shape)
print("Audio previous shape:", audio_prev.shape)

motion_prev = torch.zeros(audio_latent.shape[0], 10, 6, device=device)

mouth_open_ratio_val = 0.25
mouth_open_ratio_input = torch.tensor([mouth_open_ratio_val], device=device).unsqueeze(0)
out_motion = torch.tensor([], device=device)
B, T, audio_dim = audio_seq.shape
motion_dim = audio_model_config['x_dim']
shape_in = x_c_s.reshape(1, -1).to(device)
this_audio_prev = torch.zeros(1, 10, audio_dim, device=device)
this_motion_prev = torch.zeros(1, 10, motion_dim , device=device)
motion_prev = torch.zeros(1, 10, motion_dim , device=device)
#print("Audio input shape:", audio_seq.shape)
#print(out_motion.shape)
for batch_index in range(0, audio_seq.shape[0]):
    generated_motion = inference_manager.inference(audio_seq[batch_index:batch_index+1],
                                                shape_in, this_motion_prev, this_audio_prev,
                                                cfg_scale=0.25,
                                                mouth_open_ratio = mouth_open_ratio_input,
                                                denoising_steps=10)
    this_motion_prev = generated_motion[:, -10:, :]
    this_audio_prev = audio_seq[batch_index:batch_index+1, -10:, :]

    generated_motion = generated_motion - torch.mean(generated_motion, dim=-1, keepdim=True)
    out_motion = torch.cat((out_motion, generated_motion), dim=0)
    print(generated_motion.shape)
    #print(out_motion.shape)

generated_motion = out_motion
generated_motion.shape, motion_prev.shape

[4, 6, 7, 22, 33, 34, 40, 43, 45, 46, 48, 51, 52, 53, 57, 58, 59, 60, 61, 62] [-0.05029296875, 0.0857086181640625, -0.07587742805480957, 0.058624267578125, -0.0004341602325439453, 0.00019466876983642578, -0.038482666015625, 0.0345458984375, -0.030120849609375, 0.038360595703125, -3.0279159545898438e-05, 1.3887882232666016e-05, -0.0364990234375, 0.036102294921875, -0.043212890625, 0.046844482421875, -4.3332576751708984e-05, 1.8775463104248047e-05, -0.03326416015625, 0.057373046875, -0.03460693359375, 0.031707763671875, -0.0001958608627319336, 0.0005192756652832031, -0.0728759765625, 0.0587158203125, -0.04840087890625, 0.039642333984375, -0.00025916099548339844, 0.00048089027404785156, -0.09722900390625, 0.12469482421875, -0.1556396484375, 0.09326171875, -0.00018024444580078125, 0.00037860870361328125, -0.0279384758323431, 0.010650634765625, -0.039306640625, 0.03802490234375, -1.049041748046875e-05, 3.6954879760742188e-06, -0.032989501953125, 0.044281005859375, -0.037261962890625, 0.0433

(torch.Size([2, 65, 20]), torch.Size([1, 10, 20]))

In [12]:
# Generate Frames

# Assuming generated_motion is your output from dit_inference
# generated_motion shape: [2, 65, 63]
# motion_prev shape: [2, 10, 63]

def process_motion_batch(gen_motion_batch, motion_prev, f_s, x_s, warp_decode_func):
    start_time = datetime.now()
    frames = []
    B, T, feat_count = gen_motion_batch.shape
    full_motion = gen_motion_batch.reshape(B*T, feat_count)
    full_motion = torch.cat([motion_prev[0], full_motion], dim=0)

    pose = full_motion[:, -5:]
    exp = full_motion[:, :]
    print("pose", pose.shape, "exp", exp.shape)
    # exp_mask = torch.zeros_like(full_motion[0][0])
    # pos_mouth = [14, 17, 19, 20]
    # eye and mouth [15, 16, 18]
    # pos_eye & forehead = [1, 2, 11, 12, 13] # 1.z is mouth, 12 small eye
    # shape [0, 3, 4, 7, 8, 9, 10] # may include shape dependent pose
    # pos_cloth = [5, 6]

    # for p in pos_mouth:
    #     exp_mask[p * 3:(p + 1) * 3] = 1

    # # exp_mask = exp_mask.reshape(21,3)
    # print(full_motion[0, 10:15, :] * exp_mask)

    t_identity = torch.zeros((1, 3), dtype=torch.float32, device=device)
    pitch_identity = torch.zeros((1), dtype=torch.float32, device=device)
    yaw_identity = torch.zeros((1), dtype=torch.float32, device=device)
    roll_identity = torch.zeros((1), dtype=torch.float32, device=device)
    scale_identity = torch.ones((1), dtype=torch.float32, device=device) * 1.5

    use_identity_pose = True
    if use_identity_pose:
        t_s = t_identity
        pitch_s = pitch_identity
        yaw_s = yaw_identity
        roll_s = roll_identity
        scale_s = scale_identity
    else:
        t_s = x_s_info['t']
        pitch_s = x_s_info['pitch']
        yaw_s = x_s_info['yaw']
        roll_s = x_s_info['roll']
        scale_s = x_s_info['scale']
    t = t_s
    pitch = pitch_s
    yaw = yaw_s
    roll = roll_s
    scale = scale_s

    full_63_exp = torch.zeros(full_motion.shape[0], 63, device=device)
    for i, dim in enumerate(audio_model_config['latent_mask_1']):
        print(i, dim)
        full_63_exp[:, dim] = exp[:, i]
    full_motion = full_63_exp.reshape(-1, 63)

    x_d_list = []

    for i in tqdm(range(full_motion.shape[0]), desc="Generating x_d"):
        motion = full_motion[i].reshape(21, 3)

        # Initialize empty tensors

        # # Extract values from motion
        exp = motion #* exp_mask
        # pitch = pose[i, 0:1]
        # yaw = pose[i, 1:2]
        # roll = pose[i, 2:3]
        # t_x = pose[i, 3:4]
        # t_y = pose[i, 4:5]
        t = torch.tensor(t, device=device)

        x_d_i = scale * (x_c_s @ get_rotation_matrix(pitch, yaw, roll) + exp) + t
        x_d_list.append(x_d_i.squeeze(0))

    x_d_batch = torch.stack(x_d_list, dim=0)
    f_s_batch = f_s.expand(x_d_batch.shape[0], -1, -1, -1, -1)
    x_s_batch = x_s.expand(x_d_batch.shape[0], -1, -1)

    inference_batch_size = 4
    num_batches = (x_d_batch.shape[0] + inference_batch_size - 1) // inference_batch_size

    middle_time = datetime.now()

    frames = []
    for i in tqdm(range(num_batches), desc="Processing batches"):
        start_idx = i * inference_batch_size
        end_idx = min((i + 1) * inference_batch_size, x_d_batch.shape[0])

        batch_f_s = f_s_batch[start_idx:end_idx]
        batch_x_s = x_s_batch[start_idx:end_idx]
        batch_x_d = x_d_batch[start_idx:end_idx]

        out = warp_decode_func(batch_f_s, batch_x_s, batch_x_d)

        # Convert to numpy array
        batch_frames = (out['out'].permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)
        frames.extend(list(batch_frames))

    end_time = datetime.now()

    print(f"time to prep x_d {middle_time - start_time}s")
    print(f"time to warp and decode {end_time - middle_time}s")
    print(f"total generation time {end_time - start_time}")
    return frames

In [13]:
print("Warm up...\n")
all_frames = process_motion_batch(generated_motion, motion_prev, f_s, x_s, warp_decode)

Warm up...

pose torch.Size([140, 5]) exp torch.Size([140, 20])
0 4
1 6
2 7
3 22
4 33
5 34
6 40
7 43
8 45
9 46
10 48
11 51
12 52
13 53
14 57
15 58
16 59
17 60
18 61
19 62


  t = torch.tensor(t, device=device)
Generating x_d: 100%|██████████| 140/140 [00:00<00:00, 2057.40it/s]
Processing batches: 100%|██████████| 35/35 [00:03<00:00, 10.29it/s]

time to prep x_d 0:00:00.071558s
time to warp and decode 0:00:03.402035s
total generation time 0:00:03.473593





In [47]:
print("Without profiiling...\n")
all_frames = process_motion_batch(generated_motion, motion_prev, f_s, x_s, warp_decode)

Without profiiling...

pose torch.Size([140, 5]) exp torch.Size([140, 20])
0 4
1 6
2 7
3 22
4 33
5 34
6 40
7 43
8 45
9 46
10 48
11 51
12 52
13 53
14 57
15 58
16 59
17 60
18 61
19 62


  t = torch.tensor(t, device=device)
Generating x_d: 100%|██████████| 140/140 [00:00<00:00, 2058.09it/s]
Processing batches: 100%|██████████| 35/35 [00:03<00:00, 10.68it/s]


In [48]:
print("Profiling...\n")
activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
with profile(activities = activities, profile_memory = True, record_shapes = True, with_stack = True, with_modules = True, with_flops = True) as prof:
    with record_function("full_model_inference"):
        all_frames = process_motion_batch(generated_motion, motion_prev, f_s, x_s, warp_decode)

Profiling...

pose torch.Size([140, 5]) exp torch.Size([140, 20])
0 4
1 6
2 7
3 22
4 33
5 34
6 40
7 43
8 45
9 46
10 48
11 51
12 52
13 53
14 57
15 58
16 59
17 60
18 61
19 62


  t = torch.tensor(t, device=device)
Generating x_d: 100%|██████████| 140/140 [00:00<00:00, 472.03it/s]
Processing batches: 100%|██████████| 35/35 [00:04<00:00,  7.51it/s]


In [49]:
sort_by_keyword = 'self_cuda_time_total'
print(prof.key_averages(group_by_stack_n=5).table(sort_by=sort_by_keyword, row_limit=20))

--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                            Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  Total KFLOPs  
--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
         aten::cudnn_convolution         3.64%     180.944ms         3.64%     180.944ms      57.443us        1.560s        31.37%        1.560s     495.187us           0 b           0 b      91.15 Gb      91.15 Gb          3150            --  
            full_m

In [50]:
sort_by_keyword = 'self_cpu_time_total'
print(prof.key_averages(group_by_stack_n=5).table(sort_by=sort_by_keyword, row_limit=20))

--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                            Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  Total KFLOPs  
--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
            full_model_inference        32.36%        1.609s       100.00%        4.973s        4.973s        1.219s        24.52%        4.973s        4.973s           0 b    -245.17 Mb           0 b    -295.89 Gb             1            --  
                  

In [33]:
# 40s audio
print(f"generate motion feature shape: {generated_motion.shape}")
print(f"previous motion feature shape: {motion_prev.shape}")
print(f"f_s feature shape: {f_s.shape}")
print(f"x_s feature shape: {x_s.shape}")

generate motion feature shape: torch.Size([16, 65, 20])
previous motion feature shape: torch.Size([1, 10, 20])
f_s feature shape: torch.Size([1, 32, 16, 64, 64])
x_s feature shape: torch.Size([1, 21, 3])


In [55]:
# 3s audio
print(f"generate motion feature shape: {generated_motion.shape}")
print(f"previous motion feature shape: {motion_prev.shape}")
print(f"f_s feature shape: {f_s.shape}")
print(f"x_s feature shape: {x_s.shape}")

generate motion feature shape: torch.Size([2, 65, 20])
previous motion feature shape: torch.Size([1, 10, 20])
f_s feature shape: torch.Size([1, 32, 16, 64, 64])
x_s feature shape: torch.Size([1, 21, 3])


In [26]:
# Generate Frames Optimized with torch parallelized matrix operations

def process_motion_batch2(gen_motion_batch, motion_prev, f_s, x_s, warp_decode_func):
    start_time = datetime.now()

    generated_motion = torch.cat([motion_prev[0], gen_motion_batch.reshape(-1, gen_motion_batch.shape[-1])], dim = 0)
    # pose = generated_motion[:, -5:]
    full_motion = torch.zeros(generated_motion.shape[0], 63, device = device)
    full_motion[:, audio_model_config['latent_mask_1']] = generated_motion
    print(f"generated motion shape {generated_motion.shape}\n")
    print(f"full motion shape {full_motion.shape}\n")

    # t = torch.zeros((1, 3), dtype = torch.float32, device = device)
    pitch = torch.zeros((1), dtype = torch.float32, device = device)
    yaw = torch.zeros((1), dtype = torch.float32, device = device)
    roll = torch.zeros((1), dtype = torch.float32, device = device)
    scale = torch.ones((1), dtype = torch.float32, device = device) * 1.5
    base_pose = x_c_s @ get_rotation_matrix(pitch, yaw, roll)
    x_d_batch = (scale * (base_pose + full_motion.reshape(-1, 21, 3))).squeeze(0)
    print(f"x_d shape {x_d_batch.shape}\n")
    f_s_batch = f_s.expand(x_d_batch.shape[0], -1, -1, -1, -1)
    x_s_batch = x_s.expand(x_d_batch.shape[0], -1, -1)

    inference_batch_size = 4
    num_batches = (x_d_batch.shape[0] + inference_batch_size - 1) // inference_batch_size
    frames = []

    middle_time = datetime.now()
    t1 = 0
    t2 = 0

    for i in tqdm(range(num_batches), desc="Processing batches"):
        i1 = datetime.now()
        start_idx = i * inference_batch_size
        end_idx = min((i + 1) * inference_batch_size, x_d_batch.shape[0])
        i2 = datetime.now()
        out = warp_decode_func(f_s_batch[start_idx:end_idx], x_s_batch[start_idx:end_idx], x_d_batch[start_idx:end_idx])
        i3 = datetime.now()
        # Convert to numpy array
        batch_frames = (out['out'].permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)
        frames.extend(list(batch_frames))
        i4 = datetime.now()
        print(f"{i4 - i1}: {i2 - i1} / {i3 - i2} / {i4 - i3}")
        t1 += (i3 - i2).total_seconds()
        t2 += (i4 - i3).total_seconds()
    end_time = datetime.now()

    print(f"\ntime to prep x_d {middle_time - start_time}s")
    print(f"time to warp and decode {end_time - middle_time}s")
    print(f"total generation time {end_time - start_time}")
    print(f"{t1} {t2}")
    return frames

In [27]:
all_frames = process_motion_batch2(generated_motion, motion_prev, f_s, x_s, warp_decode)


generated motion shape torch.Size([140, 20])

full motion shape torch.Size([140, 63])

x_d shape torch.Size([140, 21, 3])



Processing batches:   6%|▌         | 2/35 [00:00<00:03, 10.51it/s]

0:00:00.093062: 0:00:00 / 0:00:00.017556 / 0:00:00.075506
0:00:00.097200: 0:00:00 / 0:00:00.015651 / 0:00:00.081549
0:00:00.095585: 0:00:00 / 0:00:00.016017 / 0:00:00.079568


Processing batches:  11%|█▏        | 4/35 [00:00<00:02, 10.52it/s]

0:00:00.093537: 0:00:00 / 0:00:00.018015 / 0:00:00.075522
0:00:00.090629: 0:00:00 / 0:00:00.019112 / 0:00:00.071517


Processing batches:  17%|█▋        | 6/35 [00:00<00:02, 10.76it/s]

0:00:00.090044: 0:00:00 / 0:00:00.016018 / 0:00:00.074026
0:00:00.090803: 0:00:00 / 0:00:00.015237 / 0:00:00.075566
0:00:00.097346: 0:00:00 / 0:00:00.015035 / 0:00:00.082311


Processing batches:  23%|██▎       | 8/35 [00:00<00:02, 10.69it/s]

0:00:00.097585: 0:00:00 / 0:00:00.018508 / 0:00:00.079077


Processing batches:  29%|██▊       | 10/35 [00:00<00:02, 10.63it/s]

0:00:00.092244: 0:00:00 / 0:00:00.017505 / 0:00:00.074739
0:00:00.090588: 0:00:00 / 0:00:00.016064 / 0:00:00.074524
0:00:00.096046: 0:00:00 / 0:00:00.015570 / 0:00:00.080476

Processing batches:  40%|████      | 14/35 [00:01<00:01, 10.76it/s]


0:00:00.090562: 0:00:00 / 0:00:00.015505 / 0:00:00.075057
0:00:00.091057: 0:00:00 / 0:00:00.016026 / 0:00:00.075031


Processing batches:  46%|████▌     | 16/35 [00:01<00:01, 10.78it/s]

0:00:00.093098: 0:00:00 / 0:00:00.016506 / 0:00:00.076592
0:00:00.091634: 0:00:00 / 0:00:00.016509 / 0:00:00.075125
0:00:00.090044: 0:00:00 / 0:00:00.016013 / 0:00:00.074031


Processing batches:  57%|█████▋    | 20/35 [00:01<00:01, 10.71it/s]

0:00:00.093614: 0:00:00 / 0:00:00.015506 / 0:00:00.078108
0:00:00.098101: 0:00:00 / 0:00:00.017517 / 0:00:00.080584
0:00:00.092676: 0:00:00 / 0:00:00.017016 / 0:00:00.075660


Processing batches:  63%|██████▎   | 22/35 [00:02<00:01, 10.74it/s]

0:00:00.092567: 0:00:00 / 0:00:00.018016 / 0:00:00.074551
0:00:00.092537: 0:00:00 / 0:00:00.016510 / 0:00:00.076027
0:00:00.093036: 0:00:00 / 0:00:00.015500 / 0:00:00.077536


Processing batches:  74%|███████▍  | 26/35 [00:02<00:00, 10.70it/s]

0:00:00.097230: 0:00:00 / 0:00:00.017511 / 0:00:00.079719
0:00:00.092582: 0:00:00 / 0:00:00.015047 / 0:00:00.077535
0:00:00.091598: 0:00:00 / 0:00:00.018511 / 0:00:00.073087


Processing batches:  80%|████████  | 28/35 [00:02<00:00, 10.60it/s]

0:00:00.094096: 0:00:00 / 0:00:00.017012 / 0:00:00.077084
0:00:00.098565: 0:00:00 / 0:00:00.017043 / 0:00:00.081522
0:00:00.090093: 0:00:00 / 0:00:00.015056 / 0:00:00.075037


Processing batches:  86%|████████▌ | 30/35 [00:02<00:00, 10.74it/s]

0:00:00.089565: 0:00:00 / 0:00:00.015539 / 0:00:00.074026
0:00:00.093544: 0:00:00 / 0:00:00.016505 / 0:00:00.077039
0:00:00.089670: 0:00:00.001000 / 0:00:00.015082 / 0:00:00.073588


Processing batches: 100%|██████████| 35/35 [00:03<00:00, 10.70it/s]

0:00:00.094557: 0:00:00 / 0:00:00.015506 / 0:00:00.079051
0:00:00.091057: 0:00:00 / 0:00:00.015506 / 0:00:00.075551
0:00:00.094598: 0:00:00 / 0:00:00.014510 / 0:00:00.080088

time to prep x_d 0:00:00.002510s
time to warp and decode 0:00:03.270764s
total generation time 0:00:03.273274
0.5737400000000001 2.6860100000000005





In [32]:
# Generate Frames Optimized with fast pinned memeory transfer

def process_motion_batch3(gen_motion_batch, motion_prev, f_s, x_s, warp_decode_func):
    start_time = datetime.now()

    generated_motion = torch.cat([motion_prev[0], gen_motion_batch.reshape(-1, gen_motion_batch.shape[-1])], dim = 0)
    # pose = generated_motion[:, -5:]
    full_motion = torch.zeros(generated_motion.shape[0], 63, device = device)
    full_motion[:, audio_model_config['latent_mask_1']] = generated_motion
    print(f"generated motion shape {generated_motion.shape}\n")
    print(f"full motion shape {full_motion.shape}\n")

    # t = torch.zeros((1, 3), dtype = torch.float32, device = device)
    pitch = torch.zeros((1), dtype = torch.float32, device = device)
    yaw = torch.zeros((1), dtype = torch.float32, device = device)
    roll = torch.zeros((1), dtype = torch.float32, device = device)
    scale = torch.ones((1), dtype = torch.float32, device = device) * 1.5
    base_pose = x_c_s @ get_rotation_matrix(pitch, yaw, roll)
    x_d_batch = (scale * (base_pose + full_motion.reshape(-1, 21, 3))).squeeze(0)
    print(f"x_d shape {x_d_batch.shape}\n")
    f_s_batch = f_s.expand(x_d_batch.shape[0], -1, -1, -1, -1)
    x_s_batch = x_s.expand(x_d_batch.shape[0], -1, -1)

    inference_batch_size = 4
    num_batches = (x_d_batch.shape[0] + inference_batch_size - 1) // inference_batch_size
    output_buffer = torch.zeros((x_d_batch.shape[0], 512, 512, 3), dtype = torch.uint8, pin_memory = True)

    middle_time = datetime.now()

    t1 = 0
    t2 = 0

    for i in tqdm(range(num_batches), desc="Processing batches"):
        i1 = datetime.now()
        # Step 1 calculate starting and ending index for current batch
        start_idx = i * inference_batch_size
        end_idx = min((i + 1) * inference_batch_size, x_d_batch.shape[0])
        i2 = datetime.now()
        # Step 2 process current batch through the warp_decode_func
        out = warp_decode_func(f_s_batch[start_idx:end_idx], x_s_batch[start_idx:end_idx], x_d_batch[start_idx:end_idx])
        i3 = datetime.now()
        # Step 3 convert to numpy array and store it in output
        output_buffer[start_idx:end_idx].copy_(out['out'].permute(0, 2, 3, 1).mul_(255).to(torch.uint8), non_blocking = False)
        i4 = datetime.now()
        print(f"{i4 - i1}: {i2 - i1} / {i3 - i2} / {i4 - i3}")
        t1 += (i3 - i2).total_seconds()
        t2 += (i4 - i3).total_seconds()

    next_time = datetime.now()

    frames = list(output_buffer.cpu().numpy())
    del output_buffer

    end_time = datetime.now()

    print(f"\ntime to prep x_d {middle_time - start_time}s")
    print(f"time to warp and decode {next_time - middle_time}s")
    print(f"time to transfer pinned to cpu {end_time - next_time}")
    print(f"total generation time {end_time - start_time}")
    print(f"{t1} {t2}")

    return frames

In [33]:
all_frames = process_motion_batch3(generated_motion, motion_prev, f_s, x_s, warp_decode)

generated motion shape torch.Size([140, 20])

full motion shape torch.Size([140, 63])

x_d shape torch.Size([140, 21, 3])



Processing batches:   6%|▌         | 2/35 [00:00<00:02, 14.57it/s]

0:00:00.069265: 0:00:00 / 0:00:00.016504 / 0:00:00.052761
0:00:00.068031: 0:00:00 / 0:00:00.019002 / 0:00:00.049029
0:00:00.063028: 0:00:00 / 0:00:00.016004 / 0:00:00.047024
0:00:00.064533: 0:00:00 / 0:00:00.016509 / 0:00:00.048024

Processing batches:  17%|█▋        | 6/35 [00:00<00:01, 14.89it/s]


0:00:00.070521: 0:00:00 / 0:00:00.015001 / 0:00:00.055520
0:00:00.064419: 0:00:00 / 0:00:00.017029 / 0:00:00.047390
0:00:00.063541: 0:00:00 / 0:00:00.017010 / 0:00:00.046531


Processing batches:  29%|██▊       | 10/35 [00:00<00:01, 15.33it/s]

0:00:00.064527: 0:00:00 / 0:00:00.016506 / 0:00:00.048021
0:00:00.064531: 0:00:00 / 0:00:00.017013 / 0:00:00.047518
0:00:00.063525: 0:00:00 / 0:00:00.016508 / 0:00:00.047017
0:00:00.063532: 0:00:00 / 0:00:00.017006 / 0:00:00.046526


Processing batches:  40%|████      | 14/35 [00:00<00:01, 15.38it/s]

0:00:00.063593: 0:00:00 / 0:00:00.019007 / 0:00:00.044586
0:00:00.065529: 0:00:00 / 0:00:00.018509 / 0:00:00.047020
0:00:00.064102: 0:00:00 / 0:00:00.015508 / 0:00:00.048594
0:00:00.064027: 0:00:00 / 0:00:00.017002 / 0:00:00.047025


Processing batches:  51%|█████▏    | 18/35 [00:01<00:01, 15.39it/s]

0:00:00.065039: 0:00:00 / 0:00:00.018506 / 0:00:00.046533
0:00:00.064585: 0:00:00 / 0:00:00.017537 / 0:00:00.047048
0:00:00.064524: 0:00:00 / 0:00:00.016511 / 0:00:00.048013
0:00:00.063656: 0:00:00 / 0:00:00.015108 / 0:00:00.048548


Processing batches:  63%|██████▎   | 22/35 [00:01<00:00, 15.42it/s]

0:00:00.064527: 0:00:00 / 0:00:00.016512 / 0:00:00.048015
0:00:00.065028: 0:00:00 / 0:00:00.018513 / 0:00:00.046515
0:00:00.064019: 0:00:00 / 0:00:00.016506 / 0:00:00.047513
0:00:00.063553: 0:00:00 / 0:00:00.015035 / 0:00:00.048518


Processing batches:  74%|███████▍  | 26/35 [00:01<00:00, 15.46it/s]

0:00:00.064565: 0:00:00 / 0:00:00.018538 / 0:00:00.046027
0:00:00.063547: 0:00:00 / 0:00:00.017003 / 0:00:00.046544
0:00:00.064525: 0:00:00 / 0:00:00.017505 / 0:00:00.047020


Processing batches:  86%|████████▌ | 30/35 [00:01<00:00, 15.49it/s]

0:00:00.064029: 0:00:00 / 0:00:00.015509 / 0:00:00.048520
0:00:00.063517: 0:00:00 / 0:00:00.016003 / 0:00:00.047514
0:00:00.064523: 0:00:00 / 0:00:00.016001 / 0:00:00.048522
0:00:00.063530: 0:00:00 / 0:00:00.017007 / 0:00:00.046523


Processing batches:  91%|█████████▏| 32/35 [00:02<00:00, 15.49it/s]

0:00:00.065049: 0:00:00 / 0:00:00.017021 / 0:00:00.048028
0:00:00.064084: 0:00:00 / 0:00:00.015003 / 0:00:00.049081
0:00:00.065027: 0:00:00.001000 / 0:00:00.015011 / 0:00:00.049016
0:00:00.064041: 0:00:00 / 0:00:00.017028 / 0:00:00.047013


Processing batches: 100%|██████████| 35/35 [00:02<00:00, 15.38it/s]

0:00:00.064033: 0:00:00 / 0:00:00.017516 / 0:00:00.046517

time to prep x_d 0:00:00.003030s
time to warp and decode 0:00:02.278204s
time to transfer pinned to cpu 0:00:00
total generation time 0:00:02.281234
0.587491 1.6736139999999997





In [74]:
# Generate Frames Optimized with multiple streams attemp 1

def process_motion_batch4(gen_motion_batch, motion_prev, f_s, x_s, warp_decode_func):
    start_time = datetime.now()

    generated_motion = torch.cat([motion_prev[0], gen_motion_batch.reshape(-1, gen_motion_batch.shape[-1])], dim = 0)
    # pose = generated_motion[:, -5:]
    full_motion = torch.zeros(generated_motion.shape[0], 63, device = device)
    full_motion[:, audio_model_config['latent_mask_1']] = generated_motion
    print(f"generated motion shape {generated_motion.shape}\n")
    print(f"full motion shape {full_motion.shape}\n")

    # t = torch.zeros((1, 3), dtype = torch.float32, device = device)
    pitch = torch.zeros((1), dtype = torch.float32, device = device)
    yaw = torch.zeros((1), dtype = torch.float32, device = device)
    roll = torch.zeros((1), dtype = torch.float32, device = device)
    scale = torch.ones((1), dtype = torch.float32, device = device) * 1.5
    base_pose = x_c_s @ get_rotation_matrix(pitch, yaw, roll)
    x_d_batch = (scale * (base_pose + full_motion.reshape(-1, 21, 3))).squeeze(0)
    print(f"x_d shape {x_d_batch.shape}\n")
    f_s_batch = f_s.expand(x_d_batch.shape[0], -1, -1, -1, -1)
    x_s_batch = x_s.expand(x_d_batch.shape[0], -1, -1)

    inference_batch_size = 4
    num_batches = (x_d_batch.shape[0] + inference_batch_size - 1) // inference_batch_size
    output_buffer = torch.zeros((x_d_batch.shape[0], 512, 512, 3), dtype = torch.uint8, pin_memory = True)
    streams = [torch.cuda.Stream(), torch.cuda.Stream()]

    middle_time = datetime.now()

    for i in tqdm(range(num_batches), desc="Processing batches"):
        i1 = datetime.now()
        # Step 1 calculate starting and ending index for current batch
        start_idx = i * inference_batch_size
        end_idx = min((i + 1) * inference_batch_size, x_d_batch.shape[0])

        with torch.cuda.stream(streams[i % 2]):
            i2 = datetime.now()
            # Step 2 process current batch through the warp_decode_func
            out = warp_decode_func(f_s_batch[start_idx:end_idx], x_s_batch[start_idx:end_idx], x_d_batch[start_idx:end_idx])
            i3 = datetime.now()
            # Step 3 write to output pinned memory buffer
            output_buffer[start_idx:end_idx].copy_(out['out'].permute(0, 2, 3, 1).mul_(255).to(torch.uint8), non_blocking = True)
        i4 = datetime.now()
        print(f"{i4 - i1}: {i2 - i1} / {i3 - i2} / {i4 - i3}")
    torch.cuda.synchronize()

    next_time = datetime.now()

    frames = list(output_buffer.cpu().numpy())
    del output_buffer
    del streams
    torch.cuda.empty_cache()

    end_time = datetime.now()

    print(f"\ntime to prep x_d {middle_time - start_time}s")
    print(f"time to warp and decode {next_time - middle_time}s")
    print(f"time to transfer pinned to cpu {end_time - next_time}")
    print(f"total generation time {end_time - start_time}")

    return frames

In [75]:
all_frames = process_motion_batch4(generated_motion, motion_prev, f_s, x_s, warp_decode)

generated motion shape torch.Size([140, 20])

full motion shape torch.Size([140, 63])

x_d shape torch.Size([140, 21, 3])



Processing batches:   6%|▌         | 2/35 [00:00<00:01, 17.95it/s]

0:00:00.036010: 0:00:00 / 0:00:00.036010 / 0:00:00
0:00:00.075439: 0:00:00 / 0:00:00.075439 / 0:00:00
0:00:00.064029: 0:00:00 / 0:00:00.063026 / 0:00:00.001003


Processing batches:  17%|█▋        | 6/35 [00:00<00:01, 16.23it/s]

0:00:00.066563: 0:00:00 / 0:00:00.066563 / 0:00:00
0:00:00.061030: 0:00:00 / 0:00:00.061030 / 0:00:00
0:00:00.062120: 0:00:00 / 0:00:00.062120 / 0:00:00
0:00:00.062967: 0:00:00 / 0:00:00.061967 / 0:00:00.001000


Processing batches:  29%|██▊       | 10/35 [00:00<00:01, 16.05it/s]

0:00:00.062640: 0:00:00 / 0:00:00.061640 / 0:00:00.001000
0:00:00.062044: 0:00:00 / 0:00:00.061042 / 0:00:00.001002
0:00:00.062749: 0:00:00 / 0:00:00.061748 / 0:00:00.001001
0:00:00.061031: 0:00:00 / 0:00:00.061031 / 0:00:00


Processing batches:  40%|████      | 14/35 [00:00<00:01, 15.97it/s]

0:00:00.063038: 0:00:00.001000 / 0:00:00.062038 / 0:00:00
0:00:00.064091: 0:00:00 / 0:00:00.064091 / 0:00:00
0:00:00.062798: 0:00:00 / 0:00:00.062798 / 0:00:00
0:00:00.061555: 0:00:00 / 0:00:00.061555 / 0:00:00


Processing batches:  51%|█████▏    | 18/35 [00:01<00:01, 15.90it/s]

0:00:00.062080: 0:00:00 / 0:00:00.062080 / 0:00:00
0:00:00.062520: 0:00:00 / 0:00:00.062520 / 0:00:00
0:00:00.063876: 0:00:00 / 0:00:00.063876 / 0:00:00
0:00:00.061036: 0:00:00 / 0:00:00.061036 / 0:00:00


Processing batches:  63%|██████▎   | 22/35 [00:01<00:00, 16.01it/s]

0:00:00.063087: 0:00:00 / 0:00:00.063087 / 0:00:00
0:00:00.060024: 0:00:00 / 0:00:00.060024 / 0:00:00
0:00:00.062519: 0:00:00 / 0:00:00.062519 / 0:00:00
0:00:00.061519: 0:00:00 / 0:00:00.061519 / 0:00:00


Processing batches:  74%|███████▍  | 26/35 [00:01<00:00, 15.98it/s]

0:00:00.062760: 0:00:00 / 0:00:00.062760 / 0:00:00
0:00:00.062050: 0:00:00 / 0:00:00.061050 / 0:00:00.001000
0:00:00.062519: 0:00:00 / 0:00:00.062519 / 0:00:00
0:00:00.060018: 0:00:00 / 0:00:00.060018 / 0:00:00


Processing batches:  86%|████████▌ | 30/35 [00:01<00:00, 16.03it/s]

0:00:00.064603: 0:00:00 / 0:00:00.064603 / 0:00:00
0:00:00.061019: 0:00:00 / 0:00:00.061019 / 0:00:00
0:00:00.061523: 0:00:00.001000 / 0:00:00.060523 / 0:00:00
0:00:00.062349: 0:00:00 / 0:00:00.062349 / 0:00:00


Processing batches:  97%|█████████▋| 34/35 [00:02<00:00, 15.93it/s]

0:00:00.062470: 0:00:00 / 0:00:00.062470 / 0:00:00
0:00:00.061523: 0:00:00 / 0:00:00.060523 / 0:00:00.001000
0:00:00.064560: 0:00:00 / 0:00:00.064560 / 0:00:00
0:00:00.062598: 0:00:00 / 0:00:00.062598 / 0:00:00


Processing batches: 100%|██████████| 35/35 [00:02<00:00, 16.01it/s]


time to prep x_d 0:00:00.006011s
time to warp and decode 0:00:02.231290s
time to transfer pinned to cpu 0:00:00.040614
total generation time 0:00:02.277915





In [None]:
# Generate Frames Optimized with multiple streams attemp 2

def process_motion_batch5(gen_motion_batch, motion_prev, f_s, x_s, warp_decode_func):
    start_time = datetime.now()

    generated_motion = torch.cat([motion_prev[0], gen_motion_batch.reshape(-1, gen_motion_batch.shape[-1])], dim = 0)
    # pose = generated_motion[:, -5:]
    full_motion = torch.zeros(generated_motion.shape[0], 63, device = device)
    full_motion[:, audio_model_config['latent_mask_1']] = generated_motion
    print(f"generated motion shape {generated_motion.shape}\n")
    print(f"full motion shape {full_motion.shape}\n")

    # t = torch.zeros((1, 3), dtype = torch.float32, device = device)
    pitch = torch.zeros((1), dtype = torch.float32, device = device)
    yaw = torch.zeros((1), dtype = torch.float32, device = device)
    roll = torch.zeros((1), dtype = torch.float32, device = device)
    scale = torch.ones((1), dtype = torch.float32, device = device) * 1.5
    base_pose = x_c_s @ get_rotation_matrix(pitch, yaw, roll)
    x_d_batch = (scale * (base_pose + full_motion.reshape(-1, 21, 3))).squeeze(0)
    print(f"x_d shape {x_d_batch.shape}\n")
    f_s_batch = f_s.expand(x_d_batch.shape[0], -1, -1, -1, -1)
    x_s_batch = x_s.expand(x_d_batch.shape[0], -1, -1)

    inference_batch_size = 4
    num_batches = (x_d_batch.shape[0] + inference_batch_size - 1) // inference_batch_size
    output_buffer = torch.zeros((x_d_batch.shape[0], 512, 512, 3), dtype = torch.uint8, pin_memory = True)
    compute_stream = torch.cuda.Stream()
    copy_stream = torch.cuda.Stream()

    middle_time = datetime.now()

    for i in tqdm(range(num_batches), desc="Processing batches"):

        # Step 1 calculate starting and ending index for current batch
        start_idx = i * inference_batch_size
        end_idx = min((i + 1) * inference_batch_size, x_d_batch.shape[0])

        with torch.cuda.stream(compute_stream):
            # Step 2 process current batch through the warp_decode_func which is basically a MLP + decoder
            out = warp_decode_func(f_s_batch[start_idx:end_idx], x_s_batch[start_idx:end_idx], x_d_batch[start_idx:end_idx])

        with torch.cuda.stream(copy_stream):
            copy_stream.wait_stream(compute_stream)
            # Step 3 convert to numpy array and store it in output
            output_buffer[start_idx:end_idx].copy_(out['out'].permute(0, 2, 3, 1).mul_(255).to(torch.uint8), non_blocking = True)

    torch.cuda.synchronize()

    next_time = datetime.now()

    frames = list(output_buffer.cpu().numpy())
    del output_buffer
    del compute_stream
    del copy_stream
    torch.cuda.empty_cache()

    end_time = datetime.now()

    print(f"\ntime to prep x_d {middle_time - start_time}s")
    print(f"time to warp and decode {next_time - middle_time}s")
    print(f"time to transfer pinned to cpu {end_time - next_time}")
    print(f"total generation time {end_time - start_time}")

    return frames

In [45]:
all_frames = process_motion_batch5(generated_motion, motion_prev, f_s, x_s, warp_decode)

generated motion shape torch.Size([140, 20])

full motion shape torch.Size([140, 63])

x_d shape torch.Size([140, 21, 3])



Processing batches: 100%|██████████| 35/35 [00:02<00:00, 15.79it/s]


time to prep x_d 0:00:00.004016s
time to warp and decode 0:00:02.262946s
time to transfer pinned to cpu 0:00:00.008503
total generation time 0:00:02.275465





In [None]:
# Generate Frames Optimized with multiple streams attemp 3

def process_motion_batch6(gen_motion_batch, motion_prev, f_s, x_s, warp_decode_func):
    start_time = datetime.now()

    generated_motion = torch.cat([motion_prev[0], gen_motion_batch.reshape(-1, gen_motion_batch.shape[-1])], dim = 0)
    # pose = generated_motion[:, -5:]
    full_motion = torch.zeros(generated_motion.shape[0], 63, device = device)
    full_motion[:, audio_model_config['latent_mask_1']] = generated_motion
    print(f"generated motion shape {generated_motion.shape}\n")
    print(f"full motion shape {full_motion.shape}\n")

    # t = torch.zeros((1, 3), dtype = torch.float32, device = device)
    pitch = torch.zeros((1), dtype = torch.float32, device = device)
    yaw = torch.zeros((1), dtype = torch.float32, device = device)
    roll = torch.zeros((1), dtype = torch.float32, device = device)
    scale = torch.ones((1), dtype = torch.float32, device = device) * 1.5
    base_pose = x_c_s @ get_rotation_matrix(pitch, yaw, roll)
    x_d_batch = (scale * (base_pose + full_motion.reshape(-1, 21, 3))).squeeze(0)
    print(f"x_d shape {x_d_batch.shape}\n")
    f_s_batch = f_s.expand(x_d_batch.shape[0], -1, -1, -1, -1)
    x_s_batch = x_s.expand(x_d_batch.shape[0], -1, -1)

    inference_batch_size = 4
    num_batches = (x_d_batch.shape[0] + inference_batch_size - 1) // inference_batch_size
    output_buffer = torch.zeros((x_d_batch.shape[0], 512, 512, 3), dtype = torch.uint8, pin_memory = True)
    compute_stream = torch.cuda.Stream()
    copy_stream = torch.cuda.Stream()

    middle_time = datetime.now()

    with torch.cuda.stream(compute_stream):
        for i in tqdm(range(num_batches), desc="Processing batches"):

            # Step 1 calculate starting and ending index for current batch
            start_idx = i * inference_batch_size
            end_idx = min((i + 1) * inference_batch_size, x_d_batch.shape[0])

            # Step 2 process current batch through the warp_decode_func which is basically a MLP + decoder
            out = warp_decode_func(f_s_batch[start_idx:end_idx], x_s_batch[start_idx:end_idx], x_d_batch[start_idx:end_idx])

            with torch.cuda.stream(copy_stream):
                # Step 3 convert to numpy array and store it in output
                output_buffer[start_idx:end_idx].copy_(out['out'].permute(0, 2, 3, 1).mul_(255).to(torch.uint8), non_blocking = True)

    torch.cuda.synchronize()

    next_time = datetime.now()

    frames = list(output_buffer.cpu().numpy())
    del output_buffer
    del compute_stream
    del copy_stream
    torch.cuda.empty_cache()

    end_time = datetime.now()

    print(f"\ntime to prep x_d {middle_time - start_time}s")
    print(f"time to warp and decode {next_time - middle_time}s")
    print(f"time to transfer pinned to cpu {end_time - next_time}")
    print(f"total generation time {end_time - start_time}")

    return frames

In [58]:
all_frames = process_motion_batch6(generated_motion, motion_prev, f_s, x_s, warp_decode)

generated motion shape torch.Size([140, 20])

full motion shape torch.Size([140, 63])

x_d shape torch.Size([140, 21, 3])



Processing batches: 100%|██████████| 35/35 [00:02<00:00, 15.78it/s]



time to prep x_d 0:00:00.003030s
time to warp and decode 0:00:02.267006s
time to transfer pinned to cpu 0:00:00.007508
total generation time 0:00:02.277544


In [None]:
# Optimized warp and decoder model

feature_3d_example = torch.randn(1, 32, 16, 64, 64, device = device)
kp_source_example = torch.randn(1, 21, 3, device = device)
kp_driving_example = torch.randn(1, 21, 3, device = device)
ret_out_example = torch.randn(1, 256, 256, 64, device = device)

traced_warping_module = torch.jit.trace(warping_module, (feature_3d_example, kp_source_example, kp_driving_example), strict = False)
traced_spade_generator = torch.jit.trace(spade_generator, ret_out_example)

def warp_decode_optimized(feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
    """ get the image after the warping of the implicit keypoints
    feature_3d: Bx32x16x64x64, feature volume
    kp_source: BxNx3
    kp_driving: BxNx3
    """
    # The line 18 in Algorithm 1: D(W(f_s; x_s, x′_d,i)）
    with torch.no_grad(), torch.autocast(device_type = 'cuda', dtype = torch.float16, enabled = inference_cfg.flag_use_half_precision):

        # torch.compiler.cudagraph_mark_step_begin()
        # get decoder input
        ret_dct = traced_warping_module(feature_3d, kp_source = kp_source, kp_driving = kp_driving)

        # decode
        ret_dct['out'] = traced_spade_generator(feature = ret_dct['out'])

    return ret_dct

In [42]:
# Generate Frames Optimized with TorchScript computation graph

def process_motion_batch7(gen_motion_batch, motion_prev, f_s, x_s, warp_decode_func):
    start_time = datetime.now()

    generated_motion = torch.cat([motion_prev[0], gen_motion_batch.reshape(-1, gen_motion_batch.shape[-1])], dim = 0)
    # pose = generated_motion[:, -5:]
    full_motion = torch.zeros(generated_motion.shape[0], 63, device = device)
    full_motion[:, audio_model_config['latent_mask_1']] = generated_motion
    print(f"generated motion shape {generated_motion.shape}\n")
    print(f"full motion shape {full_motion.shape}\n")

    # t = torch.zeros((1, 3), dtype = torch.float32, device = device)
    pitch = torch.zeros((1), dtype = torch.float32, device = device)
    yaw = torch.zeros((1), dtype = torch.float32, device = device)
    roll = torch.zeros((1), dtype = torch.float32, device = device)
    scale = torch.ones((1), dtype = torch.float32, device = device) * 1.5
    base_pose = x_c_s @ get_rotation_matrix(pitch, yaw, roll)
    x_d_batch = (scale * (base_pose + full_motion.reshape(-1, 21, 3))).squeeze(0)
    print(f"x_d shape {x_d_batch.shape}\n")
    f_s_batch = f_s.expand(x_d_batch.shape[0], -1, -1, -1, -1)
    x_s_batch = x_s.expand(x_d_batch.shape[0], -1, -1)

    inference_batch_size = 1
    num_batches = (x_d_batch.shape[0] + inference_batch_size - 1) // inference_batch_size
    output_buffer = torch.zeros((x_d_batch.shape[0], 512, 512, 3), dtype = torch.uint8, pin_memory = True)

    middle_time = datetime.now()

    for i in tqdm(range(num_batches), desc="Processing batches"):
        i1 = datetime.now()
        # Step 1 calculate starting and ending index for current batch
        start_idx = i * inference_batch_size
        end_idx = min((i + 1) * inference_batch_size, x_d_batch.shape[0])
        i2 = datetime.now()
        # Step 2 process current batch through the warp_decode_func which is basically a MLP + decoder
        out = warp_decode_func(f_s_batch[start_idx:end_idx], x_s_batch[start_idx:end_idx], x_d_batch[start_idx:end_idx])
        i3 = datetime.now()
        # Step 3 convert to numpy array and store it in output
        output_buffer[start_idx:end_idx].copy_(out['out'].permute(0, 2, 3, 1).mul_(255).to(torch.uint8), non_blocking = True)
        i4 = datetime.now()
        print(f"{i4 - i1}: {i2 - i1} / {i3 - i2} / {i4 - i3}")

    next_time = datetime.now()

    frames = list(output_buffer.cpu().numpy())
    del output_buffer

    end_time = datetime.now()

    print(f"\ntime to prep x_d {middle_time - start_time}s")
    print(f"time to warp and decode {next_time - middle_time}s")
    print(f"time to transfer pinned to cpu {end_time - next_time}")
    print(f"total generation time {end_time - start_time}")

    return frames

In [45]:
all_frames = process_motion_batch7(generated_motion, motion_prev, f_s, x_s, warp_decode_optimized)

generated motion shape torch.Size([140, 20])

full motion shape torch.Size([140, 63])

x_d shape torch.Size([140, 21, 3])



Processing batches:   4%|▍         | 6/140 [00:00<00:02, 50.60it/s]

0:00:00.013561: 0:00:00 / 0:00:00.013561 / 0:00:00
0:00:00.019985: 0:00:00 / 0:00:00.019985 / 0:00:00
0:00:00.023509: 0:00:00 / 0:00:00.022509 / 0:00:00.001000
0:00:00.019511: 0:00:00 / 0:00:00.019511 / 0:00:00
0:00:00.020502: 0:00:00 / 0:00:00.020502 / 0:00:00
0:00:00.021507: 0:00:00 / 0:00:00.021507 / 0:00:00
0:00:00.023010: 0:00:00 / 0:00:00.022010 / 0:00:00.001000
0:00:00.021507: 0:00:00 / 0:00:00.021507 / 0:00:00
0:00:00.023509: 0:00:00 / 0:00:00.023509 / 0:00:00
0:00:00.018505: 0:00:00 / 0:00:00.018505 / 0:00:00

Processing batches:  12%|█▏        | 17/140 [00:00<00:02, 48.04it/s]


0:00:00.021522: 0:00:00 / 0:00:00.021522 / 0:00:00
0:00:00.021002: 0:00:00 / 0:00:00.021002 / 0:00:00
0:00:00.022504: 0:00:00 / 0:00:00.022504 / 0:00:00
0:00:00.019021: 0:00:00 / 0:00:00.019021 / 0:00:00
0:00:00.020001: 0:00:00 / 0:00:00.020001 / 0:00:00
0:00:00.020018: 0:00:00 / 0:00:00.020018 / 0:00:00
0:00:00.021011: 0:00:00 / 0:00:00.021011 / 0:00:00
0:00:00.020003: 0:00:00 / 0:00:00.020003 / 0:00:00
0:00:00.020529: 0:00:00 / 0:00:00.019529 / 0:00:00.001000
0:00:00.019003: 0:00:00 / 0:00:00.019003 / 0:00:00


Processing batches:  21%|██        | 29/140 [00:00<00:02, 48.70it/s]

0:00:00.023520: 0:00:00 / 0:00:00.023520 / 0:00:00
0:00:00.014506: 0:00:00 / 0:00:00.014506 / 0:00:00
0:00:00.026022: 0:00:00 / 0:00:00.026022 / 0:00:00
0:00:00.018002: 0:00:00 / 0:00:00.018002 / 0:00:00
0:00:00.019508: 0:00:00 / 0:00:00.019508 / 0:00:00
0:00:00.020505: 0:00:00 / 0:00:00.020505 / 0:00:00
0:00:00.019509: 0:00:00 / 0:00:00.018509 / 0:00:00.001000
0:00:00.019508: 0:00:00 / 0:00:00.019508 / 0:00:00
0:00:00.021512: 0:00:00 / 0:00:00.021512 / 0:00:00
0:00:00.019509: 0:00:00 / 0:00:00.018509 / 0:00:00.001000
0:00:00.018011: 0:00:00 / 0:00:00.018011 / 0:00:00


Processing batches:  29%|██▊       | 40/140 [00:00<00:02, 49.14it/s]

0:00:00.020001: 0:00:00 / 0:00:00.020001 / 0:00:00
0:00:00.020507: 0:00:00 / 0:00:00.020507 / 0:00:00
0:00:00.019017: 0:00:00 / 0:00:00.018010 / 0:00:00.001007
0:00:00.022999: 0:00:00 / 0:00:00.022999 / 0:00:00
0:00:00.017525: 0:00:00 / 0:00:00.017525 / 0:00:00
0:00:00.021513: 0:00:00 / 0:00:00.021513 / 0:00:00
0:00:00.022001: 0:00:00 / 0:00:00.022001 / 0:00:00
0:00:00.018017: 0:00:00 / 0:00:00.018017 / 0:00:00
0:00:00.021002: 0:00:00 / 0:00:00.021002 / 0:00:00
0:00:00.019016: 0:00:00 / 0:00:00.019016 / 0:00:00
0:00:00.019507: 0:00:00 / 0:00:00.019507 / 0:00:00


Processing batches:  36%|███▋      | 51/140 [00:01<00:01, 49.77it/s]

0:00:00.020006: 0:00:00 / 0:00:00.020006 / 0:00:00
0:00:00.022511: 0:00:00.000999 / 0:00:00.021512 / 0:00:00
0:00:00.020008: 0:00:00 / 0:00:00.020008 / 0:00:00
0:00:00.021016: 0:00:00 / 0:00:00.021016 / 0:00:00
0:00:00.017552: 0:00:00 / 0:00:00.016554 / 0:00:00.000998
0:00:00.020023: 0:00:00 / 0:00:00.020023 / 0:00:00
0:00:00.018000: 0:00:00 / 0:00:00.018000 / 0:00:00
0:00:00.019507: 0:00:00 / 0:00:00.019507 / 0:00:00
0:00:00.019512: 0:00:00 / 0:00:00.019512 / 0:00:00
0:00:00.024004: 0:00:00 / 0:00:00.024004 / 0:00:00
0:00:00.017027: 0:00:00 / 0:00:00.017027 / 0:00:00


Processing batches:  44%|████▍     | 62/140 [00:01<00:01, 49.08it/s]

0:00:00.020019: 0:00:00 / 0:00:00.020019 / 0:00:00
0:00:00.024503: 0:00:00 / 0:00:00.024503 / 0:00:00
0:00:00.017008: 0:00:00 / 0:00:00.017008 / 0:00:00
0:00:00.021108: 0:00:00 / 0:00:00.021108 / 0:00:00
0:00:00.020501: 0:00:00 / 0:00:00.020501 / 0:00:00
0:00:00.017506: 0:00:00 / 0:00:00.017506 / 0:00:00
0:00:00.019001: 0:00:00 / 0:00:00.019001 / 0:00:00
0:00:00.019507: 0:00:00 / 0:00:00.019507 / 0:00:00
0:00:00.023506: 0:00:00 / 0:00:00.023506 / 0:00:00
0:00:00.019002: 0:00:00 / 0:00:00.019002 / 0:00:00
0:00:00.019509: 0:00:00 / 0:00:00.019509 / 0:00:00


Processing batches:  53%|█████▎    | 74/140 [00:01<00:01, 49.65it/s]

0:00:00.019505: 0:00:00 / 0:00:00.019505 / 0:00:00
0:00:00.020010: 0:00:00 / 0:00:00.020010 / 0:00:00
0:00:00.020504: 0:00:00 / 0:00:00.020504 / 0:00:00
0:00:00.020513: 0:00:00 / 0:00:00.020513 / 0:00:00
0:00:00.019003: 0:00:00 / 0:00:00.019003 / 0:00:00
0:00:00.019504: 0:00:00 / 0:00:00.018504 / 0:00:00.001000
0:00:00.021516: 0:00:00 / 0:00:00.020518 / 0:00:00.000998
0:00:00.018002: 0:00:00 / 0:00:00.018002 / 0:00:00
0:00:00.020014: 0:00:00 / 0:00:00.020014 / 0:00:00
0:00:00.021002: 0:00:00 / 0:00:00.021002 / 0:00:00
0:00:00.021024: 0:00:00 / 0:00:00.021024 / 0:00:00


Processing batches:  61%|██████    | 85/140 [00:01<00:01, 50.00it/s]

0:00:00.022506: 0:00:00 / 0:00:00.022506 / 0:00:00
0:00:00.017000: 0:00:00 / 0:00:00.017000 / 0:00:00
0:00:00.021009: 0:00:00 / 0:00:00.021009 / 0:00:00
0:00:00.018508: 0:00:00 / 0:00:00.017508 / 0:00:00.001000
0:00:00.022512: 0:00:00 / 0:00:00.022512 / 0:00:00
0:00:00.021013: 0:00:00 / 0:00:00.021013 / 0:00:00
0:00:00.017504: 0:00:00 / 0:00:00.016504 / 0:00:00.001000
0:00:00.018507: 0:00:00 / 0:00:00.017507 / 0:00:00.001000
0:00:00.020061: 0:00:00 / 0:00:00.020061 / 0:00:00
0:00:00.018001: 0:00:00 / 0:00:00.018001 / 0:00:00
0:00:00.022512: 0:00:00 / 0:00:00.022512 / 0:00:00
0:00:00.017511: 0:00:00 / 0:00:00.017511 / 0:00:00


Processing batches:  68%|██████▊   | 95/140 [00:01<00:00, 49.42it/s]

0:00:00.020508: 0:00:00 / 0:00:00.020508 / 0:00:00
0:00:00.022509: 0:00:00 / 0:00:00.022509 / 0:00:00
0:00:00.017014: 0:00:00 / 0:00:00.016015 / 0:00:00.000999
0:00:00.022004: 0:00:00 / 0:00:00.022004 / 0:00:00
0:00:00.020507: 0:00:00 / 0:00:00.020507 / 0:00:00
0:00:00.017528: 0:00:00 / 0:00:00.017528 / 0:00:00
0:00:00.022006: 0:00:00 / 0:00:00.022006 / 0:00:00
0:00:00.020515: 0:00:00 / 0:00:00.019515 / 0:00:00.001000
0:00:00.019011: 0:00:00 / 0:00:00.019011 / 0:00:00
0:00:00.021507: 0:00:00 / 0:00:00.021507 / 0:00:00
0:00:00.021011: 0:00:00 / 0:00:00.021011 / 0:00:00


Processing batches:  76%|███████▌  | 106/140 [00:02<00:00, 49.37it/s]

0:00:00.018095: 0:00:00 / 0:00:00.018095 / 0:00:00
0:00:00.019011: 0:00:00 / 0:00:00.019011 / 0:00:00
0:00:00.020511: 0:00:00 / 0:00:00.020511 / 0:00:00
0:00:00.021004: 0:00:00 / 0:00:00.021004 / 0:00:00
0:00:00.021554: 0:00:00 / 0:00:00.021554 / 0:00:00
0:00:00.018508: 0:00:00 / 0:00:00.018508 / 0:00:00
0:00:00.022515: 0:00:00 / 0:00:00.022515 / 0:00:00
0:00:00.019563: 0:00:00 / 0:00:00.019563 / 0:00:00
0:00:00.017502: 0:00:00 / 0:00:00.017502 / 0:00:00
0:00:00.020507: 0:00:00 / 0:00:00.020507 / 0:00:00


Processing batches:  84%|████████▎ | 117/140 [00:02<00:00, 49.48it/s]

0:00:00.021508: 0:00:00 / 0:00:00.021508 / 0:00:00
0:00:00.020516: 0:00:00 / 0:00:00.020516 / 0:00:00
0:00:00.019002: 0:00:00 / 0:00:00.019002 / 0:00:00
0:00:00.020504: 0:00:00 / 0:00:00.020504 / 0:00:00
0:00:00.019011: 0:00:00 / 0:00:00.019011 / 0:00:00
0:00:00.020001: 0:00:00 / 0:00:00.020001 / 0:00:00
0:00:00.018010: 0:00:00 / 0:00:00.018010 / 0:00:00
0:00:00.019503: 0:00:00 / 0:00:00.019503 / 0:00:00
0:00:00.022509: 0:00:00 / 0:00:00.022509 / 0:00:00
0:00:00.018546: 0:00:00 / 0:00:00.018546 / 0:00:00


Processing batches:  91%|█████████▏| 128/140 [00:02<00:00, 49.41it/s]

0:00:00.020001: 0:00:00 / 0:00:00.020001 / 0:00:00
0:00:00.021013: 0:00:00 / 0:00:00.021013 / 0:00:00
0:00:00.020012: 0:00:00 / 0:00:00.020012 / 0:00:00
0:00:00.020004: 0:00:00 / 0:00:00.020004 / 0:00:00
0:00:00.021014: 0:00:00 / 0:00:00.021014 / 0:00:00
0:00:00.017507: 0:00:00 / 0:00:00.017507 / 0:00:00
0:00:00.023015: 0:00:00 / 0:00:00.023015 / 0:00:00
0:00:00.018507: 0:00:00 / 0:00:00.018507 / 0:00:00
0:00:00.020508: 0:00:00 / 0:00:00.020508 / 0:00:00
0:00:00.021506: 0:00:00 / 0:00:00.021506 / 0:00:00


Processing batches:  96%|█████████▌| 134/140 [00:02<00:00, 49.66it/s]

0:00:00.017513: 0:00:00 / 0:00:00.017513 / 0:00:00
0:00:00.019023: 0:00:00 / 0:00:00.019023 / 0:00:00
0:00:00.020512: 0:00:00.001001 / 0:00:00.019511 / 0:00:00
0:00:00.017514: 0:00:00 / 0:00:00.017514 / 0:00:00
0:00:00.020502: 0:00:00 / 0:00:00.020502 / 0:00:00
0:00:00.022512: 0:00:00 / 0:00:00.022512 / 0:00:00
0:00:00.018541: 0:00:00 / 0:00:00.018541 / 0:00:00
0:00:00.019000: 0:00:00 / 0:00:00.019000 / 0:00:00
0:00:00.019505: 0:00:00 / 0:00:00.019505 / 0:00:00
0:00:00.020012: 0:00:00 / 0:00:00.020012 / 0:00:00


Processing batches: 100%|██████████| 140/140 [00:02<00:00, 49.36it/s]

0:00:00.023006: 0:00:00 / 0:00:00.023006 / 0:00:00
0:00:00.018042: 0:00:00 / 0:00:00.018042 / 0:00:00

time to prep x_d 0:00:00.003082s
time to warp and decode 0:00:02.838216s
time to transfer pinned to cpu 0:00:00
total generation time 0:00:02.841298





In [28]:
# Generate Frames sequentially (frame by frame)

def process_motion_batch8(gen_motion_batch, motion_prev, f_s, x_s, warp_decode_func):
    start_time = datetime.now()

    generated_motion = torch.cat([motion_prev[0], gen_motion_batch.reshape(-1, gen_motion_batch.shape[-1])], dim = 0)
    # pose = generated_motion[:, -5:]
    full_motion = torch.zeros(generated_motion.shape[0], 63, device = device)
    full_motion[:, audio_model_config['latent_mask_1']] = generated_motion
    print(f"generated motion shape {generated_motion.shape}\n")
    print(f"full motion shape {full_motion.shape}\n")

    # t = torch.zeros((1, 3), dtype = torch.float32, device = device)
    pitch = torch.zeros((1), dtype = torch.float32, device = device)
    yaw = torch.zeros((1), dtype = torch.float32, device = device)
    roll = torch.zeros((1), dtype = torch.float32, device = device)
    scale = torch.ones((1), dtype = torch.float32, device = device) * 1.5
    base_pose = x_c_s @ get_rotation_matrix(pitch, yaw, roll)
    x_d_batch = (scale * (base_pose + full_motion.reshape(-1, 21, 3))).squeeze(0)
    print(f"x_d shape {x_d_batch.shape}\n")

    output_buffer = torch.zeros((x_d_batch.shape[0], 512, 512, 3), dtype = torch.uint8, pin_memory = True)

    middle_time = datetime.now()

    for i in tqdm(range(x_d_batch.shape[0]), desc="Processing batches"):
        i1 = datetime.now()
        # Step 1 process current frame through the warp_decode_func
        out = warp_decode_func(f_s, x_s, x_d_batch[i].unsqueeze(0))
        i2 = datetime.now()
        # Step 2 write to output pinned memory buffer
        output_buffer[i].copy_(out['out'].permute(0, 2, 3, 1).mul_(255).to(torch.uint8).squeeze(0), non_blocking = False)
        i3 = datetime.now()
        print(f"{i3 - i1}: {i2 - i1} / {i3 - i2}")

    next_time = datetime.now()

    frames = list(output_buffer.cpu().numpy())
    del output_buffer

    end_time = datetime.now()

    print(f"\ntime to prep x_d {middle_time - start_time}s")
    print(f"time to warp and decode {next_time - middle_time}s")
    print(f"time to transfer pinned to cpu {end_time - next_time}")
    print(f"total generation time {end_time - start_time}")

    return frames

In [29]:
all_frames = process_motion_batch8(generated_motion, motion_prev, f_s, x_s, warp_decode)

generated motion shape torch.Size([140, 20])

full motion shape torch.Size([140, 63])

x_d shape torch.Size([140, 21, 3])



Processing batches:   1%|          | 1/140 [00:00<00:15,  8.84it/s]

0:00:00.113147: 0:00:00.112147 / 0:00:00.001000
0:00:00.020004: 0:00:00.015009 / 0:00:00.004995
0:00:00.022024: 0:00:00.016030 / 0:00:00.005994


Processing batches:   8%|▊         | 11/140 [00:00<00:03, 38.65it/s]

0:00:00.020557: 0:00:00.015561 / 0:00:00.004996
0:00:00.021091: 0:00:00.015090 / 0:00:00.006001
0:00:00.021844: 0:00:00.014842 / 0:00:00.007002
0:00:00.020509: 0:00:00.015511 / 0:00:00.004998
0:00:00.021502: 0:00:00.018503 / 0:00:00.002999
0:00:00.020016: 0:00:00.017010 / 0:00:00.003006
0:00:00.020545: 0:00:00.016037 / 0:00:00.004508
0:00:00.020001: 0:00:00.016001 / 0:00:00.004000
0:00:00.020505: 0:00:00.015505 / 0:00:00.005000
0:00:00.021012: 0:00:00.017011 / 0:00:00.004001
0:00:00.019508: 0:00:00.016508 / 0:00:00.003000


Processing batches:  11%|█▏        | 16/140 [00:00<00:02, 42.40it/s]

0:00:00.020507: 0:00:00.017508 / 0:00:00.002999
0:00:00.021016: 0:00:00.018017 / 0:00:00.002999
0:00:00.020002: 0:00:00.017002 / 0:00:00.003000
0:00:00.020621: 0:00:00.017621 / 0:00:00.003000
0:00:00.020033: 0:00:00.017038 / 0:00:00.002995
0:00:00.020503: 0:00:00.017503 / 0:00:00.003000


Processing batches:  15%|█▌        | 21/140 [00:00<00:02, 44.79it/s]

0:00:00.019506: 0:00:00.014509 / 0:00:00.004997
0:00:00.020528: 0:00:00.016003 / 0:00:00.004525
0:00:00.020011: 0:00:00.017013 / 0:00:00.002998
0:00:00.022539: 0:00:00.019540 / 0:00:00.002999
0:00:00.020020: 0:00:00.017014 / 0:00:00.003006


Processing batches:  19%|█▊        | 26/140 [00:00<00:02, 45.74it/s]

0:00:00.021002: 0:00:00.018001 / 0:00:00.003001
0:00:00.020027: 0:00:00.017029 / 0:00:00.002998
0:00:00.021534: 0:00:00.018534 / 0:00:00.003000
0:00:00.021515: 0:00:00.018516 / 0:00:00.002999
0:00:00.020517: 0:00:00.017517 / 0:00:00.003000


Processing batches:  22%|██▏       | 31/140 [00:00<00:02, 46.06it/s]

0:00:00.022507: 0:00:00.018507 / 0:00:00.004000
0:00:00.020505: 0:00:00.018505 / 0:00:00.002000
0:00:00.022013: 0:00:00.019010 / 0:00:00.003003
0:00:00.021507: 0:00:00.017511 / 0:00:00.003996
0:00:00.020002: 0:00:00.017002 / 0:00:00.003000


Processing batches:  26%|██▌       | 36/140 [00:00<00:02, 46.48it/s]

0:00:00.020507: 0:00:00.017505 / 0:00:00.003002
0:00:00.021043: 0:00:00.019047 / 0:00:00.001996
0:00:00.019118: 0:00:00.015118 / 0:00:00.004000
0:00:00.021009: 0:00:00.016505 / 0:00:00.004504
0:00:00.019509: 0:00:00.014510 / 0:00:00.004999
0:00:00.024014: 0:00:00.019002 / 0:00:00.005012


Processing batches:  29%|██▉       | 41/140 [00:00<00:02, 46.46it/s]

0:00:00.021000: 0:00:00.018000 / 0:00:00.003000
0:00:00.020009: 0:00:00.017009 / 0:00:00.003000
0:00:00.019508: 0:00:00.016509 / 0:00:00.002999


Processing batches:  36%|███▋      | 51/140 [00:01<00:01, 47.82it/s]

0:00:00.020510: 0:00:00.017504 / 0:00:00.003006
0:00:00.021025: 0:00:00.017026 / 0:00:00.003999
0:00:00.020508: 0:00:00.017508 / 0:00:00.003000
0:00:00.020015: 0:00:00.016015 / 0:00:00.004000
0:00:00.020011: 0:00:00.015013 / 0:00:00.004998
0:00:00.020004: 0:00:00.017005 / 0:00:00.002999
0:00:00.021035: 0:00:00.017530 / 0:00:00.003505
0:00:00.021505: 0:00:00.017505 / 0:00:00.004000
0:00:00.020508: 0:00:00.017508 / 0:00:00.003000
0:00:00.020614: 0:00:00.017612 / 0:00:00.003002


Processing batches:  44%|████▎     | 61/140 [00:01<00:01, 47.58it/s]

0:00:00.021103: 0:00:00.017104 / 0:00:00.003999
0:00:00.020505: 0:00:00.017505 / 0:00:00.003000
0:00:00.021013: 0:00:00.016505 / 0:00:00.004508
0:00:00.022014: 0:00:00.019015 / 0:00:00.002999
0:00:00.021003: 0:00:00.018004 / 0:00:00.002999
0:00:00.020508: 0:00:00.014506 / 0:00:00.006002
0:00:00.021014: 0:00:00.018012 / 0:00:00.003002
0:00:00.019510: 0:00:00.016513 / 0:00:00.002997
0:00:00.020019: 0:00:00.016512 / 0:00:00.003507
0:00:00.021509: 0:00:00.018508 / 0:00:00.003001

Processing batches:  51%|█████     | 71/140 [00:01<00:01, 48.18it/s]


0:00:00.020002: 0:00:00.016003 / 0:00:00.003999
0:00:00.020019: 0:00:00.015016 / 0:00:00.005003
0:00:00.021021: 0:00:00.016025 / 0:00:00.004996
0:00:00.020511: 0:00:00.016514 / 0:00:00.003997
0:00:00.019506: 0:00:00.014506 / 0:00:00.005000
0:00:00.021507: 0:00:00.017508 / 0:00:00.003999
0:00:00.021001: 0:00:00.017003 / 0:00:00.003998
0:00:00.021011: 0:00:00.018011 / 0:00:00.003000
0:00:00.020023: 0:00:00.014005 / 0:00:00.006018
0:00:00.020006: 0:00:00.015007 / 0:00:00.004999


Processing batches:  58%|█████▊    | 81/140 [00:01<00:01, 48.33it/s]

0:00:00.019511: 0:00:00.014513 / 0:00:00.004998
0:00:00.019543: 0:00:00.014508 / 0:00:00.005035
0:00:00.020516: 0:00:00.017515 / 0:00:00.003001
0:00:00.022008: 0:00:00.019009 / 0:00:00.002999
0:00:00.020510: 0:00:00.015004 / 0:00:00.005506
0:00:00.019504: 0:00:00.015505 / 0:00:00.003999
0:00:00.021035: 0:00:00.018035 / 0:00:00.003000
0:00:00.021515: 0:00:00.017505 / 0:00:00.004010
0:00:00.020004: 0:00:00.017004 / 0:00:00.003000
0:00:00.020508: 0:00:00.018508 / 0:00:00.002000
0:00:00.024017: 0:00:00.020017 / 0:00:00.004000


Processing batches:  65%|██████▌   | 91/140 [00:01<00:01, 48.17it/s]

0:00:00.019517: 0:00:00.016511 / 0:00:00.003006
0:00:00.020505: 0:00:00.017505 / 0:00:00.003000
0:00:00.021010: 0:00:00.017010 / 0:00:00.004000
0:00:00.020004: 0:00:00.017003 / 0:00:00.003001
0:00:00.020014: 0:00:00.016012 / 0:00:00.004002
0:00:00.020575: 0:00:00.015513 / 0:00:00.005062
0:00:00.022506: 0:00:00.019506 / 0:00:00.003000
0:00:00.020503: 0:00:00.015504 / 0:00:00.004999
0:00:00.020013: 0:00:00.016507 / 0:00:00.003506
0:00:00.021005: 0:00:00.017004 / 0:00:00.004001
0:00:00.020015: 0:00:00.017015 / 0:00:00.003000


Processing batches:  76%|███████▌  | 106/140 [00:02<00:00, 48.56it/s]

0:00:00.021516: 0:00:00.017008 / 0:00:00.004508
0:00:00.020002: 0:00:00.017001 / 0:00:00.003001
0:00:00.020537: 0:00:00.017539 / 0:00:00.002998
0:00:00.020018: 0:00:00.015512 / 0:00:00.004506
0:00:00.020505: 0:00:00.015504 / 0:00:00.005001
0:00:00.020508: 0:00:00.015510 / 0:00:00.004998
0:00:00.019508: 0:00:00.016003 / 0:00:00.003505
0:00:00.020513: 0:00:00.017514 / 0:00:00.002999
0:00:00.021032: 0:00:00.016030 / 0:00:00.005002
0:00:00.019506: 0:00:00.016507 / 0:00:00.002999
0:00:00.021519: 0:00:00.017516 / 0:00:00.004003


Processing batches:  83%|████████▎ | 116/140 [00:02<00:00, 48.36it/s]

0:00:00.020503: 0:00:00.016502 / 0:00:00.004001
0:00:00.019004: 0:00:00.015007 / 0:00:00.003997
0:00:00.020028: 0:00:00.017030 / 0:00:00.002998
0:00:00.021018: 0:00:00.018017 / 0:00:00.003001
0:00:00.020504: 0:00:00.017505 / 0:00:00.002999
0:00:00.020509: 0:00:00.017507 / 0:00:00.003002
0:00:00.020507: 0:00:00.015507 / 0:00:00.005000
0:00:00.019005: 0:00:00.014008 / 0:00:00.004997
0:00:00.023514: 0:00:00.020010 / 0:00:00.003504
0:00:00.021508: 0:00:00.019507 / 0:00:00.002001
0:00:00.021001: 0:00:00.018001 / 0:00:00.003000


Processing batches:  90%|█████████ | 126/140 [00:02<00:00, 48.22it/s]

0:00:00.020023: 0:00:00.017020 / 0:00:00.003003
0:00:00.021023: 0:00:00.019023 / 0:00:00.002000
0:00:00.020001: 0:00:00.017004 / 0:00:00.002997
0:00:00.020007: 0:00:00.017008 / 0:00:00.002999
0:00:00.021503: 0:00:00.017504 / 0:00:00.003999
0:00:00.019505: 0:00:00.017505 / 0:00:00.002000
0:00:00.020011: 0:00:00.017009 / 0:00:00.003002
0:00:00.021017: 0:00:00.015017 / 0:00:00.006000
0:00:00.019002: 0:00:00.015002 / 0:00:00.004000
0:00:00.021013: 0:00:00.015506 / 0:00:00.005507


Processing batches:  97%|█████████▋| 136/140 [00:02<00:00, 47.70it/s]

0:00:00.021512: 0:00:00.017515 / 0:00:00.003997
0:00:00.025012: 0:00:00.021505 / 0:00:00.003507
0:00:00.022001: 0:00:00.019001 / 0:00:00.003000
0:00:00.021503: 0:00:00.016507 / 0:00:00.004996
0:00:00.019507: 0:00:00.017002 / 0:00:00.002505
0:00:00.021508: 0:00:00.018004 / 0:00:00.003504
0:00:00.020507: 0:00:00.017508 / 0:00:00.002999
0:00:00.021015: 0:00:00.016508 / 0:00:00.004507
0:00:00.020000: 0:00:00.016001 / 0:00:00.003999
0:00:00.019507: 0:00:00.016507 / 0:00:00.003000
0:00:00.020013: 0:00:00.016003 / 0:00:00.004010


Processing batches: 100%|██████████| 140/140 [00:03<00:00, 46.62it/s]

0:00:00.019507: 0:00:00.016504 / 0:00:00.003003

time to prep x_d 0:00:00.026601s
time to warp and decode 0:00:03.010739s
time to transfer pinned to cpu 0:00:00
total generation time 0:00:03.037340





In [None]:
# Stream frame by frame

from time import sleep

def process_motion_stream(gen_motion_batch, motion_prev, f_s, x_s, warp_decode_func, output_buffer):
    # Process generated motion feature
    generated_motion = torch.cat([motion_prev[0], gen_motion_batch.reshape(-1, gen_motion_batch.shape[-1])], dim = 0)
    full_motion = torch.zeros(generated_motion.shape[0], 63, device = device)
    full_motion[:, audio_model_config['latent_mask_1']] = generated_motion

    # Process generated x_d (full motion feature)
    pitch = torch.zeros((1), dtype = torch.float32, device = device)
    yaw = torch.zeros((1), dtype = torch.float32, device = device)
    roll = torch.zeros((1), dtype = torch.float32, device = device)
    scale = torch.ones((1), dtype = torch.float32, device = device) * 1.5
    base_pose = x_c_s @ get_rotation_matrix(pitch, yaw, roll)
    x_d_batch = (scale * (base_pose + full_motion.reshape(-1, 21, 3))).squeeze(0)
    #output_buffer = torch.zeros((x_d_batch.shape[0], 512, 512, 3), dtype = torch.uint8, pin_memory = True)

    # Process frame
    for i in range(x_d_batch.shape[0]):
        # Step 1 process current frame through the warp_decode_func
        out = warp_decode_func(f_s, x_s, x_d_batch[i].unsqueeze(0))

        # Step 2 write to output pinned memory buffer
        output_buffer.put(out['out'].permute(0, 2, 3, 1).mul_(255).to(torch.uint8).squeeze(0))

    # Terminate
    output_buffer.put(None)

def display_frames(output_buffer, pre_time):
    while True:
        # Retrieve the next frame from the buffer
        frame_tensor = output_buffer.get()

        # End of processing
        if frame_tensor is None:
            break

        # Transfer the frame from pinned memory back to CPU
        frame = frame_tensor.cpu().numpy()
        result_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

        # Ensure 25 fps
        if (datetime.now() - pre_time).total_seconds() < 0.024:
            sleep(0.024 - (datetime.now() - pre_time).total_seconds())

        # Display the frame using OpenCV
        print(f"time to display {datetime.now() - pre_time}")
        cv2.imshow("Video Stream", result_bgr)
        cv2.waitKey(1)
        pre_time = datetime.now()

def stream_frames(generated_motion, motion_prev, f_s, x_s, warp_decode):
    # Record start time
    start_time = datetime.now()

    # Output buffer tracking
    output_buffer = Queue(maxsize = 5000)

    # Start frame processing thread
    processing_thread = Thread(target = process_motion_stream, args=(generated_motion, motion_prev, f_s, x_s, warp_decode, output_buffer))
    processing_thread.start()

    # Start display frame
    display_frames(output_buffer, start_time)

    # Wait until terminate
    processing_thread.join()

    # Cleanup
    cv2.destroyAllWindows()

In [43]:
stream_frames(generated_motion, motion_prev, f_s, x_s, warp_decode)

time to display 0:00:00.122051
time to display 0:00:00.024516
time to display 0:00:00.024510
time to display 0:00:00.024518
time to display 0:00:00.025009
time to display 0:00:00.024514
time to display 0:00:00.024517
time to display 0:00:00.024515
time to display 0:00:00.024507
time to display 0:00:00.024512
time to display 0:00:00.025010
time to display 0:00:00.024507
time to display 0:00:00.025015
time to display 0:00:00.025009
time to display 0:00:00.024514
time to display 0:00:00.024511
time to display 0:00:00.025012
time to display 0:00:00.025203
time to display 0:00:00.024515
time to display 0:00:00.024505
time to display 0:00:00.025011
time to display 0:00:00.024508
time to display 0:00:00.024527
time to display 0:00:00.025013
time to display 0:00:00.024505
time to display 0:00:00.024504
time to display 0:00:00.025016
time to display 0:00:00.025008
time to display 0:00:00.024527
time to display 0:00:00.025046
time to display 0:00:00.024516
time to display 0:00:00.024503
time to 

In [30]:
# Save as MP4

# Remove the files if they exist
if os.path.exists(output_no_audio_path):
    os.remove(output_no_audio_path)
if os.path.exists(output_video):
    os.remove(output_video)
fps = 25  # Adjust as needed

height, width, layers = all_frames[0].shape
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
video = cv2.VideoWriter(output_no_audio_path, fourcc, fps, (width, height))

for frame in all_frames:
    video.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

video.release()

# Add audio to the video using ffmpeg
input_video = output_no_audio_path
input_audio = used_audio_example  # Use the path to your audio file

ffmpeg_cmd = [
    'ffmpeg',
    '-i', input_video,
    '-i', input_audio,
    '-c:v', 'copy',
    '-c:a', 'aac',
    '-shortest',
    output_video
]

try:
    subprocess.run(ffmpeg_cmd, check=True)
    os.remove(output_no_audio_path)
    print(f"Video with audio saved to {output_video}")
except subprocess.CalledProcessError as e:
    print(f"Error adding audio to video: {e}")

Video with audio saved to D:/Projects/Upenn_CIS_5650/final-project/LivePortrait/inference/animations/test5_with_audio.mp4
