In [None]:
import os

### 1. Decide input combinations

In [None]:
# Toggle read by the main sweep loop: when True, every generated video is scored
# with call_syncnet and the scores are written to result_dict.json.
run_syncnet = True

Input combinations

In [None]:

# Source portrait images; the main loop renders every audio clip against every portrait.
portrait_imgs = [
    '/mnt/c/Users/mjh/Downloads/live_in/t4.jpg'
]
# Driving audio clips; one video is produced per (audio, portrait) pair.
audio_paths = [
    # '/mnt/c/Users/mjh/Downloads/live_in/i3.wav',
    '/mnt/c/Users/mjh/Downloads/live_in/i5.wav',
    '/mnt/c/Users/mjh/Downloads/live_in/i7.wav',
    '/mnt/c/Users/mjh/Downloads/live_in/i8.wav'
]
# (config_path, weight_path) pairs for the audio-to-motion model; each pair gets
# its own timestamped output folder in the main loop.
model_weights_pairs = [
    # ('audio_dit/output/config.json', 'audio_dit/output/model_1023.pth'),
    # ('audio_dit/output/config.json', 'audio_dit/output/model_sterotype_0_125.pth'),
    ('audio_dit/output/config.json', 'audio_dit/output/model_sterotype_1_140.pth'),
]

Sampling options

In [None]:
# Values forwarded as `cfg_scale` to inference_manager.inference; the main loop
# sweeps the full cross product of the three option lists in this cell.
cfg_scale_opts = [
    0,
    0.25,
    0.5,
    0.75,
    # 1,
    # 1.5,
    # 2
]
# Values forwarded (as a 1x1 tensor) as `mouth_open_ratio` to inference_manager.inference.
mouth_open_ratio_opts = [
    # 0.1,
    0.15,
    # 0.2,
    0.225,
    0.25,
    0.275,
    # 0.3
]
# Whether inference_one_input subtracts a mean from each generated motion chunk.
subtract_avg_motion_opts = [
    False,
    # True
]


### 2. Generate videos & evaluate SyncNet on each output

In [None]:
# import liveportrait modules
import time
import os
import contextlib
import os.path as osp
import numpy as np
import cv2
import torch
import yaml
import tyro
import subprocess
from rich.progress import track
import torchvision
import cv2
import threading
import queue
import torchvision.transforms as transforms
from concurrent.futures import ThreadPoolExecutor, as_completed
import glob
import os
import numpy as np
import time
import torch
import imageio

from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig
from src.config.crop_config import CropConfig

def partial_fields(target_class, kwargs):
    """Instantiate ``target_class`` from the subset of ``kwargs`` it recognises.

    Only entries whose key satisfies ``hasattr(target_class, key)`` are passed
    to the constructor; every other entry is silently dropped.
    """
    accepted = {}
    for key, value in kwargs.items():
        if hasattr(target_class, key):
            accepted[key] = value
    return target_class(**accepted)

# Build the default argument set, then derive the inference and crop configs from
# it; partial_fields forwards only the attributes each config class declares.
args = ArgumentConfig()
inference_cfg = partial_fields(InferenceConfig, args.__dict__)
crop_cfg = partial_fields(CropConfig, args.__dict__)
# print("inference_cfg: ", inference_cfg)
# print("crop_cfg: ", crop_cfg)
# NOTE(review): hard-coded to CUDA here, while the wav2vec cell later rebinds
# `device` with a CPU fallback — confirm whether CPU-only runs are meant to work.
device = 'cuda'
print("Compile complete")

'''
Common modules
'''
from src.utils.helper import load_model
from src.utils.camera import headpose_pred_to_degree, get_rotation_matrix
from src.config.inference_config import InferenceConfig
from src.utils.cropper import Cropper
from src.utils.camera import get_rotation_matrix
from src.utils.io import load_image_rgb

'''
Main module for inference
'''
# Parse the model definition file that load_model uses to build each network.
# The file is read inside a context manager so the handle is closed immediately
# instead of lingering until garbage collection.
with open(inference_cfg.models_config, 'r') as models_config_file:
    model_config = yaml.load(models_config_file, Loader=yaml.SafeLoader)
# init F
appearance_feature_extractor = load_model(inference_cfg.checkpoint_F, model_config, device, 'appearance_feature_extractor')
# init M
motion_extractor = load_model(inference_cfg.checkpoint_M, model_config, device, 'motion_extractor')
# init W
warping_module = load_model(inference_cfg.checkpoint_W, model_config, device, 'warping_module')
# init G
spade_generator = load_model(inference_cfg.checkpoint_G, model_config, device, 'spade_generator')
# init S and R (optional: skipped when no checkpoint path is configured or the file is missing)
if inference_cfg.checkpoint_S is not None and os.path.exists(inference_cfg.checkpoint_S):
    stitching_retargeting_module = load_model(inference_cfg.checkpoint_S, model_config, device, 'stitching_retargeting_module')
else:
    stitching_retargeting_module = None

# Cropper instance shared by the inference functions below.
cropper = Cropper(crop_cfg=crop_cfg, device=device)

'''
Main function for inference
'''

def get_kp_info(x: torch.Tensor, **kwargs) -> dict:
    """ run the motion extractor and return its implicit keypoint information
    x: Bx3xHxW, normalized to 0~1
    flag_refine_info (keyword, default True): when set, the head-pose entries are
        passed through headpose_pred_to_degree and given a trailing axis (Bx1), and
        'kp' / 'exp' are reshaped to BxNx3
    return: the dict produced by the motion extractor; expected keys are
        'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'
    """
    half_precision = inference_cfg.flag_use_half_precision
    autocast_ctx = torch.autocast(device_type='cuda', dtype=torch.float16, enabled=half_precision)
    with torch.no_grad(), autocast_ctx:
        kp_info = motion_extractor(x)

        if half_precision:
            # bring every tensor entry back to float32
            for key in list(kp_info.keys()):
                value = kp_info[key]
                if isinstance(value, torch.Tensor):
                    kp_info[key] = value.float()

    if not kwargs.get('flag_refine_info', True):
        return kp_info

    batch_size = kp_info['kp'].shape[0]
    for angle_key in ('pitch', 'yaw', 'roll'):
        kp_info[angle_key] = headpose_pred_to_degree(kp_info[angle_key])[:, None]  # Bx1
    for points_key in ('kp', 'exp'):
        kp_info[points_key] = kp_info[points_key].reshape(batch_size, -1, 3)  # BxNx3
    return kp_info

def prepare_source(img: np.ndarray) -> torch.Tensor:
    """ construct the network input from an image (or a batch of images)
    img: HxWx3 or BxHxWx3, uint8 (the caller resizes to 256x256 beforehand)
    return: Bx3xHxW float32 tensor on `device`, values clipped to 0~1
    raises: ValueError when img is neither 3- nor 4-dimensional
    """
    if img.ndim == 3:
        batch = img[np.newaxis]  # HxWx3 -> 1xHxWx3
    elif img.ndim == 4:
        batch = img  # already BxHxWx3
    else:
        raise ValueError(f'img ndim should be 3 or 4: {img.ndim}')

    # astype allocates a new float32 array, so the caller's image is never
    # modified and no explicit copy is needed beforehand.
    x = batch.astype(np.float32) / 255.  # normalized to 0~1
    x = np.clip(x, 0, 1)  # clip to 0~1
    x = torch.from_numpy(x).permute(0, 3, 1, 2)  # BxHxWx3 -> Bx3xHxW
    return x.to(device)

def warp_decode(feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
    """ warp the source feature volume towards the driving keypoints and decode it
    feature_3d: Bx32x16x64x64, feature volume
    kp_source: BxNx3
    kp_driving: BxNx3
    return: the dict produced by the warping module, with its 'out' entry replaced
        by the generator output (a dict, despite the return annotation)
    """
    use_half = inference_cfg.flag_use_half_precision
    autocast_ctx = torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_half)
    # The line 18 in Algorithm 1: D(W(f_s; x_s, x'_d,i))
    with torch.no_grad(), autocast_ctx:
        # warping produces the decoder input ...
        warped = warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
        # ... which the generator turns into the output stored back under 'out'
        decoded = spade_generator(feature=warped['out'])
        warped['out'] = decoded

    return warped

def extract_feature_3d(x: torch.Tensor) -> torch.Tensor:
    """ compute the appearance feature of the image with F
    x: Bx3xHxW, normalized to 0~1
    return: the extractor output cast to float32
    """
    use_half = inference_cfg.flag_use_half_precision
    autocast_ctx = torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_half)
    with torch.no_grad(), autocast_ctx:
        volume = appearance_feature_extractor(x)
    return volume.float()

def transform_keypoint(kp_info: dict):
    """
    apply head pose, expression deformation, scale and shift to the implicit keypoints
    kp_info: dict read for the keys 'kp', 'pitch', 'yaw', 'roll', 't', 'exp', 'scale'
    kp: BxNx3 (a flattened Bx(N*3) layout is accepted as well)
    return: transformed keypoints, BxNx3
    """
    canonical_kp = kp_info['kp']    # (bs, k, 3)
    raw_angles = [kp_info[name] for name in ('pitch', 'yaw', 'roll')]
    translation = kp_info['t']
    deformation = kp_info['exp']
    scale = kp_info['scale']

    pitch, yaw, roll = (headpose_pred_to_degree(angle) for angle in raw_angles)

    batch_size = canonical_kp.shape[0]
    # Bx(num_kpx3) when flattened, otherwise Bxnum_kpx3
    num_kp = canonical_kp.shape[1] // 3 if canonical_kp.ndim == 2 else canonical_kp.shape[1]

    rot_mat = get_rotation_matrix(pitch, yaw, roll)    # (bs, 3, 3)

    # Eqn.2: s * (R * x_c,s + exp) + t
    kp_transformed = canonical_kp.view(batch_size, num_kp, 3) @ rot_mat
    kp_transformed = kp_transformed + deformation.view(batch_size, num_kp, 3)
    kp_transformed *= scale[..., None]  # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3)
    kp_transformed[:, :, 0:2] += translation[:, None, 0:2]  # remove z, only apply tx ty

    return kp_transformed

In [None]:
# import wav2vec modules
import json
import torch
import torchaudio
from transformers import Wav2Vec2Model, Wav2Vec2Processor
import os
import numpy as np
from typing import List, Tuple
import torch.multiprocessing as mp
import torch.distributed as dist
import torch.nn.functional as F
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

# Constants
MODEL_NAME = "facebook/wav2vec2-base-960h"  # identifier passed to from_pretrained below
TARGET_SAMPLE_RATE = 16000  # Hz; audio is resampled to this rate before encoding
FRAME_RATE = 25  # frames per second the audio features are resampled to
SECTION_LENGTH = 3  # nominal section length in seconds (see load_and_process_audio)
OVERLAP = 10  # overlap between consecutive sections, in frames at FRAME_RATE

# The five constants below are not referenced anywhere in the visible notebook.
# NOTE(review): presumably left over from a dataset preprocessing script — confirm and remove.
DB_ROOT = 'vox2-audio-tx'
LOG = 'log'
AUDIO = 'audio/audio'
OUTPUT_DIR = 'audio_encoder_output'
BATCH_SIZE = 16


# Move model and processor to global scope
# This rebinds `device` (set to the string 'cuda' in the LivePortrait cell) to a
# torch.device with a CPU fallback; later cells use this value.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
wav2vec_model = Wav2Vec2Model.from_pretrained(MODEL_NAME).to(device)
wav2vec_processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)

def load_and_process_audio(file_path):
    """Load an audio file and cut it into fixed-length, overlapping sections.

    The waveform is resampled to TARGET_SAMPLE_RATE, mixed down to mono,
    left-padded with one overlap worth of silence, and then split into sections
    of ``section_samples`` samples. Consecutive sections share
    ``overlap_samples`` samples, and the final section is zero-padded on the
    right to full length.

    Args:
        file_path: path of the audio file to load.

    Returns:
        tuple ``(file_path, sections)`` where ``sections`` is a list of 1-D
        tensors, each ``section_samples`` long. Clips shorter than one section
        yield a single padded section, using the same return layout as longer
        clips so callers can always unpack ``(path, sections)``.
    """
    waveform, sample_rate = torchaudio.load(file_path)

    if sample_rate != TARGET_SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, sample_rate, TARGET_SAMPLE_RATE)

    # Convert to mono if stereo
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Section length and overlap in samples.
    # NOTE(review): 16027 (not TARGET_SAMPLE_RATE) is kept as-is; presumably tuned
    # so the encoder emits a round number of frames per section — confirm before changing.
    section_samples = SECTION_LENGTH * 16027
    overlap_samples = int(OVERLAP / FRAME_RATE * TARGET_SAMPLE_RATE)

    # Prepend one overlap of silence so the first section has the same leading
    # context as every later one.
    waveform = torch.nn.functional.pad(waveform, (overlap_samples, 0))
    total_samples = waveform.shape[1]

    # Split into sections with overlap. A clip shorter than one section takes
    # the "last section" branch on the first iteration and is padded there.
    sections = []
    start = 0
    while start < total_samples:
        end = start + section_samples
        chunk = waveform[:, start:end]  # slicing clamps `end` to the waveform length
        if end >= total_samples:
            # Last section: right-pad with zeros up to the full section length.
            chunk = torch.nn.functional.pad(chunk, (0, section_samples - chunk.shape[1]))
            sections.append(chunk.squeeze(0))
            break
        sections.append(chunk.squeeze(0))
        start = end - overlap_samples

    return file_path, sections

def inference_process_wav_file(path):
    """Encode an audio file into per-section wav2vec features.

    The file is split by load_and_process_audio, the sections are run through
    the wav2vec processor and model, and the last hidden state is linearly
    interpolated along the time axis to ``int(seq_length * FRAME_RATE / 50)`` steps.

    Returns a tensor laid out as (sections, resampled_steps, hidden_dim).
    """
    _, sections = load_and_process_audio(path)
    section_array = np.array(sections)

    encoded = wav2vec_processor(section_array, sampling_rate=TARGET_SAMPLE_RATE, return_tensors="pt", padding=True)
    input_values = encoded.input_values.to(device)

    with torch.no_grad():
        hidden = wav2vec_model(input_values).last_hidden_state

        # 50 is presumably the encoder's native feature rate — TODO confirm
        target_length = int(hidden.shape[1] * (FRAME_RATE / 50))

        # interpolate works on the last axis, so time is moved there and back
        resampled = F.interpolate(hidden.transpose(1, 2),
                                  size=target_length,
                                  mode='linear',
                                  align_corners=True)
        return resampled.transpose(1, 2)


In [None]:
# import dit modules
from audio_dit.inference import InferenceManager, get_model
from audio_dit.dataset import load_and_process_pair

def process_motion_batch(gen_motion_batch, motion_prev, f_s, x_s, x_c_s, x_s_info, audio_model_config, warp_decode_func):
    """Render generated motion into a list of video frames.

    Args:
        gen_motion_batch: generated motion of shape (B, T, feat_count); it is
            flattened to (B*T, feat_count), one row per output frame.
        motion_prev: not used; kept so existing callers keep working.
        f_s: source appearance feature volume, expanded over all frames.
        x_s: transformed source keypoints, expanded over all frames.
        x_c_s: canonical source keypoints that the expression offsets are added to.
        x_s_info: source keypoint dict; 't', 'pitch' and 'scale' are read
            ('yaw' and 'roll' only when ``use_identity_pose`` is disabled).
        audio_model_config: dict whose 'latent_mask_1' lists, for each generated
            feature column, the destination index in the 63-value expression vector.
        warp_decode_func: callable ``(f_s, x_s, x_d) -> dict`` whose 'out' entry is
            an image batch laid out as Bx3xHxW.

    Returns:
        list of HxWx3 uint8 frames, in generation order.
    """
    B, T, feat_count = gen_motion_batch.shape
    full_motion = gen_motion_batch.reshape(B * T, feat_count)

    # Keypoint index notes kept from earlier experiments:
    #   pos_mouth = [14, 17, 19, 20]
    #   eye and mouth [15, 16, 18]
    #   pos_eye & forehead = [1, 2, 11, 12, 13]  # 1.z is mouth, 12 small eye
    #   shape [0, 3, 4, 7, 8, 9, 10]  # may include shape dependent pose
    #   pos_cloth = [5, 6]

    yaw_identity = torch.zeros((1), dtype=torch.float32, device=device)
    roll_identity = torch.zeros((1), dtype=torch.float32, device=device)

    use_identity_pose = True
    if use_identity_pose:
        # Keep the source translation and scale, zero yaw/roll, and offset the
        # source pitch by -10.
        t_s = x_s_info['t']
        pitch_s = x_s_info['pitch'] - 10
        yaw_s = yaw_identity
        roll_s = roll_identity
        scale_s = x_s_info['scale']
    else:
        t_s = x_s_info['t']
        pitch_s = x_s_info['pitch']
        yaw_s = x_s_info['yaw']
        roll_s = x_s_info['roll']
        scale_s = x_s_info['scale']

    # Scatter the generated feature columns into the full 63-value (21x3)
    # expression vector; positions not named by the mask stay zero.
    full_63_exp = torch.zeros(full_motion.shape[0], 63, device=device)
    for src_col, dst_col in enumerate(audio_model_config['latent_mask_1']):
        full_63_exp[:, dst_col] = full_motion[:, src_col]
    exp_frames = full_63_exp.reshape(-1, 21, 3)

    # Pose terms are identical for every frame, so they are computed once.
    # as_tensor avoids the copy-construct warning torch.tensor() raises when
    # handed an existing tensor; detach keeps the result out of any autograd graph.
    t = torch.as_tensor(t_s, device=device).detach()
    rotated_canonical = x_c_s @ get_rotation_matrix(pitch_s, yaw_s, roll_s)

    x_d_list = []
    for i in tqdm(range(exp_frames.shape[0]), desc="Generating x_d"):
        # NOTE(review): the full translation (including z) is added here, whereas
        # transform_keypoint applies only tx/ty — confirm this is intended.
        x_d_i = scale_s * (rotated_canonical + exp_frames[i]) + t
        x_d_list.append(x_d_i.squeeze(0))

    x_d_batch = torch.stack(x_d_list, dim=0)
    f_s_batch = f_s.expand(x_d_batch.shape[0], -1, -1, -1, -1)
    x_s_batch = x_s.expand(x_d_batch.shape[0], -1, -1)

    inference_batch_size = 4
    num_batches = (x_d_batch.shape[0] + inference_batch_size - 1) // inference_batch_size

    frames = []
    for i in tqdm(range(num_batches), desc="Processing batches"):
        start_idx = i * inference_batch_size
        end_idx = min((i + 1) * inference_batch_size, x_d_batch.shape[0])

        out = warp_decode_func(f_s_batch[start_idx:end_idx],
                               x_s_batch[start_idx:end_idx],
                               x_d_batch[start_idx:end_idx])

        # Bx3xHxW -> BxHxWx3 on the CPU; clip before the uint8 cast so values
        # outside [0, 1] saturate instead of wrapping around.
        batch_images = out['out'].permute(0, 2, 3, 1).cpu().numpy() * 255
        batch_frames = np.clip(batch_images, 0, 255).astype(np.uint8)
        frames.extend(batch_frames)

    return frames

def write_video(all_frames, audio_path, output_path):
    """Encode frames to an mp4 and mux the driving audio into it with ffmpeg.

    The frames are first written without audio to a fixed temporary file
    (relative to the current working directory); ffmpeg then copies that video
    stream and adds the audio as AAC, cut to the shorter of the two streams.

    Args:
        all_frames: non-empty sequence of HxWx3 RGB uint8 frames of equal size.
        audio_path: audio file to mux into the video.
        output_path: destination of the final video; replaced if it exists.

    Raises:
        ValueError: if ``all_frames`` is empty.

    A failing ffmpeg run is reported on stdout rather than raised, and the
    silent temporary file is then left in place.
    """
    if len(all_frames) == 0:
        raise ValueError('write_video received no frames to encode')

    output_no_audio_path = 'output/audio_driven_video_no_audio.mp4'
    output_video = output_path

    # The writer and ffmpeg are handed these paths directly, so make sure the
    # parent folders exist before anything is written.
    for target in (output_no_audio_path, output_video):
        parent_dir = os.path.dirname(target)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    # Remove the files if they exist
    for stale_path in (output_no_audio_path, output_video):
        if os.path.exists(stale_path):
            os.remove(stale_path)

    fps = 25  # Adjust as needed

    height, width, _ = all_frames[0].shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video = cv2.VideoWriter(output_no_audio_path, fourcc, fps, (width, height))
    try:
        for frame in all_frames:
            # frames are RGB; the writer is fed BGR
            video.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # Always release so the container is finalised even if a frame fails.
        video.release()

    # Add audio to the video using ffmpeg
    ffmpeg_cmd = [
        'ffmpeg',
        '-i', output_no_audio_path,
        '-i', audio_path,
        '-c:v', 'copy',
        '-c:a', 'aac',
        '-shortest',
        output_video
    ]

    try:
        subprocess.run(ffmpeg_cmd, check=True)
        os.remove(output_no_audio_path)
        # print(f"Video with audio saved to {output_video}")
    except subprocess.CalledProcessError as e:
        print(f"Error adding audio to video: {e}")

In [None]:
# audio-driven inference function
def inference_one_input(audio_path, portrait_path, output_vid_path, inference_manager, audio_model_config, cfg_s, mouth_ratio, subtract_avg_motion):
    """Generate one audio-driven video for a single (audio, portrait) pair.

    Args:
        audio_path: driving audio file.
        portrait_path: source portrait image.
        output_vid_path: where the final video (with audio) is written.
        inference_manager: object whose ``inference`` method produces motion for
            one audio section at a time.
        audio_model_config: model config dict; 'x_dim' gives the motion feature size.
        cfg_s: value forwarded as ``cfg_scale``.
        mouth_ratio: value forwarded, as a 1x1 tensor, as ``mouth_open_ratio``.
        subtract_avg_motion: when true, the mean over the last axis is subtracted
            from each generated chunk.
    """
    # load audio: the first 10 steps of every section are context, the rest drive motion
    audio_latent = inference_process_wav_file(audio_path)
    audio_seq = audio_latent[:, 10:, :]
    audio_prev = audio_latent[:, :10, :]

    # load portrait
    img_rgb = load_image_rgb(portrait_path)
    # NOTE(review): the landmark result is not used; the call is kept in case it
    # is relied on to fail early on unusable portraits — confirm and remove.
    cropper.calc_lmk_from_cropped_image(img_rgb)
    img_crop_256x256 = cv2.resize(img_rgb, (256, 256))  # force to resize to 256x256
    I_s = prepare_source(img_crop_256x256)
    x_s_info = get_kp_info(I_s)
    x_c_s = x_s_info['kp']
    x_s = transform_keypoint(x_s_info)
    f_s = extract_feature_3d(I_s)

    # inference
    mouth_open_ratio_input = torch.tensor([mouth_ratio], device=device).unsqueeze(0)
    motion_dim = audio_model_config['x_dim']
    shape_in = x_c_s.reshape(1, -1).to(device)
    # The first section has no preceding motion, so it starts from zeros.
    this_motion_prev = torch.zeros(1, 10, motion_dim, device=device)

    # Sections are generated one by one because each one is conditioned on the
    # last 10 motion steps of the section before it.
    motion_chunks = []
    for batch_index in range(audio_seq.shape[0]):
        generated_motion = inference_manager.inference(audio_seq[batch_index:batch_index+1],
                                                       shape_in, this_motion_prev,
                                                       audio_prev[batch_index:batch_index+1],
                                                       cfg_scale=cfg_s,
                                                       mouth_open_ratio=mouth_open_ratio_input,
                                                       denoising_steps=10)
        # Taken before the optional mean subtraction, so the next section is
        # conditioned on the unmodified motion.
        this_motion_prev = generated_motion[:, -10:, :]

        if subtract_avg_motion:
            # NOTE(review): the mean is taken over the feature axis (dim=-1), not
            # over time — confirm this matches the intent of "average motion".
            generated_motion = generated_motion - torch.mean(generated_motion, dim=-1, keepdim=True)
        motion_chunks.append(generated_motion)

    # One concatenation at the end instead of growing a tensor every iteration.
    out_motion = torch.cat(motion_chunks, dim=0)
    all_frames = process_motion_batch(out_motion, this_motion_prev, f_s, x_s, x_c_s, x_s_info, audio_model_config, warp_decode)
    # write to video
    write_video(all_frames, audio_path, output_vid_path)


In [None]:
# syncnet inference function
from syncnet.syncnet import syncnet_inference

def call_syncnet(output_vid_path, tmp_dir):
    """Score a generated video with SyncNet.

    Args:
        output_vid_path: path of the video to evaluate; its file name without
            extension is used as the SyncNet reference name.
        tmp_dir: working directory handed to syncnet_inference.

    Returns:
        tuple ``(confidence, min_dist)`` taken from the SyncNet results, or
        ``(None, None)`` when SyncNet produced no results, so callers can always
        unpack two values.
    """
    # Extract the reference from the video filename
    video_basename = os.path.basename(output_vid_path)
    reference = os.path.splitext(video_basename)[0]

    results, _activesd = syncnet_inference(output_vid_path, reference, tmp_dir, keep_output=False)
    if not results:
        return None, None
    return results['confidence'], results['min_dist']

def change_working_dir_to_script_location():
    """Switch the process working directory to the LivePortrait project root.

    The relative paths used by the notebook (model configs, 'output/',
    './sync_output/') are resolved against this directory.
    """
    project_root = '/mnt/e/wsl_projects/LivePortrait/'
    os.chdir(project_root)


In [None]:
# main function
# Sweeps every (model, cfg scale, mouth ratio, subtract flag, audio, portrait)
# combination, renders one video per combination and, when run_syncnet is set,
# records its SyncNet scores in <output_root>/result_dict.json.
change_working_dir_to_script_location()
sync_tmp_dir = './sync_output/tmp'
sync_tmp_dir_abs = os.path.abspath(sync_tmp_dir)
os.makedirs(sync_tmp_dir_abs, exist_ok=True)
print(f'sync_tmp_dir_abs: {sync_tmp_dir_abs}')
for model_weights_pair in model_weights_pairs:
    # The working directory is reset before each stage; presumably a callee may
    # change it — TODO confirm which one.
    change_working_dir_to_script_location()
    config_path, weight_path = model_weights_pair
    weight_basename = os.path.basename(weight_path).split('.')[0]
    # One output folder per model, stamped with the local time at minute resolution.
    output_root = f'./sync_output/{time.strftime("%Y-%m-%d-%H-%M")}_{weight_basename}'
    output_root_abs = os.path.abspath(output_root)
    os.makedirs(output_root_abs, exist_ok=True)
    print(f'output_root_abs: {output_root_abs}')

    print("config_path exists:", os.path.exists(config_path))
    # Context manager so the config file handle is closed right after parsing.
    with open(config_path) as config_file:
        audio_model_config = json.load(config_file)
    inference_manager = get_model(config_path, weight_path, device)
    result_dict = {}
    json_output_path = os.path.join(output_root_abs, 'result_dict.json')
    for cfg_s in cfg_scale_opts:
        for mouth_ratio in mouth_open_ratio_opts:
            for subtract_avg_motion in subtract_avg_motion_opts:
                change_working_dir_to_script_location()
                config_name = f'cfg_{cfg_s}_mouth_{mouth_ratio}_subtract_{subtract_avg_motion}'
                output_parent = os.path.join(output_root_abs, config_name)
                os.makedirs(output_parent, exist_ok=True)
                print(f'processing {config_name}')
                result_dict[config_name] = {}
                for audio_path in audio_paths:
                    audio_basename = os.path.basename(audio_path)
                    for portrait_path in portrait_imgs:
                        change_working_dir_to_script_location()
                        portrait_basename = os.path.basename(portrait_path)
                        vid_name = f'audio_{audio_basename}_img_{portrait_basename}.mp4'
                        output_vid_path = os.path.join(output_parent, vid_name)
                        print(f'    processing {output_vid_path}')

                        inference_one_input(audio_path, portrait_path, output_vid_path, inference_manager, audio_model_config, cfg_s, mouth_ratio, subtract_avg_motion)
                        output_vid_path_abs = os.path.abspath(output_vid_path)
                        if run_syncnet:
                            # Tolerate a missing result so one failed evaluation
                            # does not abort the whole sweep.
                            sync_scores = call_syncnet(output_vid_path_abs, sync_tmp_dir_abs)
                            s_c, s_d = sync_scores if sync_scores is not None else (None, None)
                            result_dict[config_name][vid_name] = [s_c, s_d]
                            print(f'    {vid_name} done, confidence: {s_c}, min_dist: {s_d}')
                            # Rewritten after every video so partial results survive
                            # an interrupted run; the with-block flushes and closes the file.
                            with open(json_output_path, 'w') as json_file:
                                json.dump(result_dict, json_file, indent=4)


In [None]:
# s_c, s_d = call_syncnet(output_vid_path_abs, sync_tmp_dir_abs)
# print(f'    {vid_name} done, confidence: {s_c}, min_dist: {s_d}')