In [None]:
import os

Input combinations: portrait images, driving audio clips, and (config, weights) model checkpoint pairs to try.

In [None]:

# Candidate inputs. Toggle entries by (un)commenting them; every entry keeps a
# trailing comma so that un-commenting a neighbour never silently concatenates
# two adjacent string literals into one bogus path.
# NOTE(review): only portrait_imgs[0] is read by later cells; audio_paths and
# model_weights_pairs are not referenced elsewhere in this notebook.
LIVE_IN_DIR = '/mnt/c/Users/mjh/Downloads/live_in'  # machine-specific input folder

portrait_imgs = [
    os.path.join(LIVE_IN_DIR, 't4.jpg'),
]
audio_paths = [
    # os.path.join(LIVE_IN_DIR, 'i3.wav'),
    # os.path.join(LIVE_IN_DIR, 'i5.wav'),
    # os.path.join(LIVE_IN_DIR, 'i7.wav'),
    os.path.join(LIVE_IN_DIR, 'i8.wav'),
    # os.path.join(LIVE_IN_DIR, 'speech.wav'),
]
# (config, weights) checkpoint pairs
model_weights_pairs = [
    # ('audio_dit/output/config.json', 'audio_dit/output/model_1023.pth'),
    # ('audio_dit/output/config.json', 'audio_dit/output/model_sterotype_0_125.pth'),
    # ('audio_dit/output/config.json', 'audio_dit/output/model_sterotype_1_140.pth'),
    # ('audio_dit/output/config.json', 'audio_dit/output/model_1111.pth'),
    # ('audio_dit/output/config.json', 'audio_dit/output/model_0.2_2.pth'),
    # ('audio_dit/output/config.json', 'audio_dit/output/model_5_5.pth'),
    # ('audio_dit/output/config.json', 'audio_dit/output/redeem_5_vel_5_acc_ep_90.pth'),
    ('audio_dit/output/config.json', 'audio_dit/output/norm_no_vel_ep_60.pth'),
]


In [None]:
# import liveportrait modules
import time
import os
import contextlib
import os.path as osp
import numpy as np
import cv2
import torch
import yaml
import tyro
import subprocess
from rich.progress import track
import torchvision
import cv2
import threading
import queue
import torchvision.transforms as transforms
from concurrent.futures import ThreadPoolExecutor, as_completed
import glob
import os
import numpy as np
import time
import torch
import imageio

from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig
from src.config.crop_config import CropConfig

def partial_fields(target_class, kwargs):
    """Instantiate ``target_class`` using only the entries of ``kwargs`` it knows about.

    A key is forwarded to the constructor when ``target_class`` has an attribute
    of that name (e.g. a class-level default); every other key is silently dropped.
    """
    accepted = {}
    for name, value in kwargs.items():
        if hasattr(target_class, name):
            accepted[name] = value
    return target_class(**accepted)

args = ArgumentConfig()
# Project the argument fields onto each config class; keys a class does not
# declare are dropped by partial_fields.
inference_cfg = partial_fields(InferenceConfig, vars(args))
crop_cfg = partial_fields(CropConfig, vars(args))
device = 'cuda'  # model loading and the autocast contexts in this notebook target CUDA
print("Compile complete")  # progress marker only; nothing is compiled in this cell

'''
Common modules
'''
from src.utils.helper import load_model
from src.utils.camera import headpose_pred_to_degree, get_rotation_matrix
from src.config.inference_config import InferenceConfig
from src.utils.cropper import Cropper
from src.utils.camera import get_rotation_matrix
from src.utils.io import load_image_rgb

'''
Main module for inference
'''
# Read the model-architecture config. The `with` block closes the file handle
# (a bare open() passed to the loader was never closed); yaml.safe_load is the
# same as yaml.load(..., Loader=yaml.SafeLoader).
with open(inference_cfg.models_config, 'r') as models_config_file:
    model_config = yaml.safe_load(models_config_file)
# init F: appearance feature extractor
appearance_feature_extractor = load_model(inference_cfg.checkpoint_F, model_config, device, 'appearance_feature_extractor')
# init M: motion extractor
motion_extractor = load_model(inference_cfg.checkpoint_M, model_config, device, 'motion_extractor')
# init W: warping module
warping_module = load_model(inference_cfg.checkpoint_W, model_config, device, 'warping_module')
# init G: SPADE generator (decoder)
spade_generator = load_model(inference_cfg.checkpoint_G, model_config, device, 'spade_generator')
# init S and R: optional; left as None when no checkpoint is configured or the file is missing
if inference_cfg.checkpoint_S is not None and os.path.exists(inference_cfg.checkpoint_S):
    stitching_retargeting_module = load_model(inference_cfg.checkpoint_S, model_config, device, 'stitching_retargeting_module')
else:
    stitching_retargeting_module = None

# NOTE(review): cropper is not used by the later cells, which only cv2.resize the portrait.
cropper = Cropper(crop_cfg=crop_cfg, device=device)

'''
Main function for inference
'''

def get_kp_info(x: torch.Tensor, **kwargs) -> dict:
    """Run the motion extractor M on an image batch and collect its implicit keypoint info.

    x: Bx3xHxW image batch, normalized to 0~1
    flag_refine_info (keyword, default True): when true, convert pitch/yaw/roll to
        degrees (shaped Bx1) and reshape 'kp' and 'exp' to BxNx3
    return: dict with the keys 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'
    """
    use_fp16 = inference_cfg.flag_use_half_precision
    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_fp16):
        kp_info = motion_extractor(x)

        if use_fp16:
            # promote every tensor output back to float32
            for key in list(kp_info):
                value = kp_info[key]
                if isinstance(value, torch.Tensor):
                    kp_info[key] = value.float()

    if kwargs.get('flag_refine_info', True):
        batch_size = kp_info['kp'].shape[0]
        for angle in ('pitch', 'yaw', 'roll'):
            kp_info[angle] = headpose_pred_to_degree(kp_info[angle])[:, None]  # Bx1
        for key in ('kp', 'exp'):
            kp_info[key] = kp_info[key].reshape(batch_size, -1, 3)  # BxNx3

    return kp_info

def prepare_source(img: np.ndarray) -> torch.Tensor:
    """Convert an image (or image batch) into a normalized model input tensor on `device`.

    img: HxWx3 or BxHxWx3 array with 0~255 pixel values (the notebook passes a
        256x256 uint8 image)
    return: float32 tensor of shape 1x3xHxW (or Bx3xHxW), values clipped to 0~1
    raises: ValueError if img is neither 3- nor 4-dimensional
    """
    if img.ndim == 3:
        x = img[np.newaxis]  # HxWx3 -> 1xHxWx3
    elif img.ndim == 4:
        x = img  # already BxHxWx3
    else:
        raise ValueError(f'img ndim should be 3 or 4: {img.ndim}')
    # astype returns a new array, so the caller's image is never modified.
    x = np.clip(x.astype(np.float32) / 255., 0, 1)  # normalize and clip to 0~1
    x = torch.from_numpy(x).permute(0, 3, 1, 2)  # BxHxWx3 -> Bx3xHxW
    return x.to(device)

def warp_decode(feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> dict:
    """Warp the source feature volume by the keypoints and decode the result.

    feature_3d: Bx32x16x64x64, feature volume
    kp_source: BxNx3
    kp_driving: BxNx3
    return: the dict produced by the warping module, with its 'out' entry replaced
        by the SPADE generator output. When half precision is enabled, tensor
        entries are cast back to float32, as get_kp_info / extract_feature_3d do.
    """
    # The line 18 in Algorithm 1: D(W(f_s; x_s, x′_d,i))
    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16,
                                         enabled=inference_cfg.flag_use_half_precision):
        # get decoder input
        ret_dct = warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
        # decode
        ret_dct['out'] = spade_generator(feature=ret_dct['out'])

    if inference_cfg.flag_use_half_precision:
        # float the dict so downstream numpy conversion sees float32
        for k, v in ret_dct.items():
            if isinstance(v, torch.Tensor):
                ret_dct[k] = v.float()

    return ret_dct

def extract_feature_3d(x: torch.Tensor) -> torch.Tensor:
    """Compute the appearance feature volume of an image batch with F.

    x: Bx3xHxW image batch, normalized to 0~1
    return: the appearance feature volume as a float32 tensor
    """
    use_fp16 = inference_cfg.flag_use_half_precision
    with torch.no_grad():
        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_fp16):
            volume = appearance_feature_extractor(x)
    return volume.float()

def transform_keypoint(kp_info: dict):
    """Apply pose, expression deformation, scale and translation to the canonical keypoints.

    Implements Eqn. 2, s * (x_c @ R + exp) + t, where only the x/y components of
    the translation t are applied (z is left untouched).

    kp_info: dict with 'kp', 'pitch', 'yaw', 'roll', 't', 'exp', 'scale'
        (as produced by get_kp_info); 'kp' may be Bx(N*3) or BxNx3
    return: BxNx3 transformed keypoints
    """
    canonical_kp = kp_info['kp']  # (bs, k, 3) or (bs, k*3)
    translation, expression = kp_info['t'], kp_info['exp']
    scale = kp_info['scale']

    # NOTE(review): get_kp_info (default flag_refine_info=True) already converts these
    # angles to degrees; this second pass is only harmless if headpose_pred_to_degree
    # leaves already-converted values unchanged — confirm in src.utils.camera.
    pitch, yaw, roll = (headpose_pred_to_degree(kp_info[name]) for name in ('pitch', 'yaw', 'roll'))

    batch_size = canonical_kp.shape[0]
    # A 2-D 'kp' packs the three coordinates of every keypoint into one row.
    num_kp = canonical_kp.shape[1] // 3 if canonical_kp.ndim == 2 else canonical_kp.shape[1]

    rot_mat = get_rotation_matrix(pitch, yaw, roll)  # (bs, 3, 3)

    kp_transformed = (canonical_kp.view(batch_size, num_kp, 3) @ rot_mat
                      + expression.view(batch_size, num_kp, 3))
    kp_transformed *= scale[..., None]  # (bs, k, 3) * (bs, 1, 1)
    kp_transformed[:, :, 0:2] += translation[:, None, 0:2]  # tx, ty only; z is not shifted
    return kp_transformed

In [None]:
image_path = portrait_imgs[0]
offset_std = 0.002  # not used in this cell

# Read the portrait and resize (not crop) it to the 256x256 model input.
img_rgb = load_image_rgb(image_path)
img_crop_256x256 = cv2.resize(img_rgb, (256, 256))
I_s = prepare_source(img_crop_256x256)

# Source keypoints: canonical kp, expression deformation, and the transformed kp.
x_s_info = get_kp_info(I_s)
x_c_s, x_exp = x_s_info['kp'], x_s_info['exp']
x_s = transform_keypoint(x_s_info)

In [None]:
# Display which entries get_kp_info returned for the source image.
x_s_info.keys()


In [None]:
# Three-entry reference values.
# NOTE(review): not referenced by any other cell in this notebook; their meaning
# (per-axis bounds / rest values for expression offsets?) is unconfirmed.
max_val = [0.018] * 3
rest_val = [0.006, -0.002, -0.002]
min_val = [-0.005, -0.015, -0.015]

In [None]:
image_path = portrait_imgs[0]
offset_std = 0.002  # std-dev of the random z-jitter added to the source keypoints
jitter_seed = 0     # fixed seed so re-running this cell reproduces the same render

# Load the portrait, resize (not crop) it to the 256x256 model input, extract keypoints.
img_rgb = load_image_rgb(image_path)
img_crop_256x256 = cv2.resize(img_rgb, (256, 256))
I_s = prepare_source(img_crop_256x256)

x_s_info = get_kp_info(I_s)
x_c_s = x_s_info['kp']
x_exp = x_s_info['exp']
x_scale = x_s_info['scale']
x_s = transform_keypoint(x_s_info)

# Neutral pose: zero translation and zero pitch/yaw/roll.
t_identity = torch.zeros((1, 3), dtype=torch.float32, device=device)
pitch_identity = torch.zeros((1), dtype=torch.float32, device=device)
yaw_identity = torch.zeros((1), dtype=torch.float32, device=device)
roll_identity = torch.zeros((1), dtype=torch.float32, device=device)
scale_identity = torch.ones((1), dtype=torch.float32, device=device) * 1.5  # currently unused

pitch = pitch_identity
yaw = yaw_identity
roll = roll_identity

# Extract features
f_s = extract_feature_3d(I_s)

# Hand-crafted expression: zero deformation except a few (keypoint, axis) entries.
# These are the same elements as flat offsets 4, 33, 45, 48 of the 21x3 tensor.
x_exp = torch.zeros_like(x_exp).reshape(1, 21, 3)
x_exp[0, 1, 1] = -0.018
x_exp[0, 11, 0] = 0.018
x_exp[0, 15, 0] = 0.011
x_exp[0, 16, 0] = 0.011

# Driving keypoints for the neutral pose: s * (x_c @ R + exp) + t
x_d_i = x_scale * (x_c_s @ get_rotation_matrix(pitch, yaw, roll) + x_exp) + t_identity

# Jitter the z coordinate of all source keypoints by one shared, seeded random offset.
jitter_generator = torch.Generator(device=device).manual_seed(jitter_seed)
random_offset = torch.randn(1, generator=jitter_generator, device=device) * offset_std
x_s[:, :, -1] += random_offset

# Render, then clip to 0~1 before scaling: numpy's float -> uint8 cast does not
# saturate, so decoder values slightly outside the range would overflow.
out = warp_decode(f_s, x_s, x_d_i)
out_frames = out['out'].float().permute(0, 2, 3, 1).cpu().numpy()
out_img = (np.clip(out_frames, 0, 1) * 255).astype(np.uint8)
print(out_img.shape)


In [None]:
from l2cs import Pipeline, render
import cv2
import os
import torch

# Estimate gaze direction on the rendered frame with L2CS-Net (ResNet50, Gaze360 weights).
gaze_pipeline = Pipeline(
    weights='/mnt/c/Users/mjh/Downloads/L2CSNet_gaze360.pkl',
    arch='ResNet50',
    device=torch.device('cpu'),  # or torch.device('cuda') for GPU inference
)
# out_img holds RGB pixels (rendered from an image read by load_image_rgb).
# NOTE(review): the L2CS Pipeline follows the OpenCV convention and converts its face
# crops BGR -> RGB internally, so the frame is handed over in BGR order — confirm
# against the installed l2cs version.
frame = cv2.cvtColor(out_img[0], cv2.COLOR_RGB2BGR)
results = gaze_pipeline.step(frame)
# L2CS reports pitch/yaw in radians; convert to degrees
gaze_pitch = np.degrees(results.pitch)
gaze_yaw = np.degrees(results.yaw)

print(f'gaze_pitch: {gaze_pitch}, gaze_yaw: {gaze_yaw}')

In [None]:
import matplotlib.pyplot as plt
# Show the first rendered frame; explicit fig/ax keeps the figure self-contained
# and plt.show() suppresses the AxesImage repr.
fig, ax = plt.subplots()
ax.imshow(out_img[0])
ax.set_title('Rendered frame (neutral pose, hand-set expression)')
ax.axis('off')
plt.show()