In [1]:
#!/usr/bin/env python
import sys
import os

# Move one directory up (the project base folder); relative paths used later in
# this notebook ('config/...', 'demo/...') are resolved from there. Guarded with a
# sentinel so re-running this cell in the same kernel does not keep climbing
# the directory tree.
if not globals().get('_CWD_MOVED_TO_BASE_FOLDER', False):
    os.chdir('../')
    _CWD_MOVED_TO_BASE_FOLDER = True

import torch
import cv2
import numpy as np
from flame_model.FLAME import FLAMEModel
from renderer.renderer import Renderer
import argparse
import torch.nn.functional as F
from pytorch3d.transforms import matrix_to_euler_angles
import subprocess
import tempfile 
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import librosa
import soundfile as sf
import torchaudio
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import wandb
import glob
from models.stage2 import CodeTalker
import yaml
from models import get_model
from base.baseTrainer import load_state_dict
from types import SimpleNamespace
from transformers import AutoProcessor, Wav2Vec2Processor, Wav2Vec2FeatureExtractor
import pickle
from scipy.signal import savgol_filter

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# FLAME head model (n_shape=300 shape / n_exp=50 expression parameters) and the
# mesh renderer, both placed on `device`.
flame = FLAMEModel(n_shape=300, n_exp=50).to(device)
renderer = Renderer(render_full_head=True).to(device)



In [None]:
def load_and_flatten_yaml(config_path):
    """Load a YAML config and flatten its top-level sections into one namespace.

    Every key nested one level under a top-level section (e.g. DATA, NETWORK)
    becomes an attribute of the returned namespace; the section names themselves
    are dropped. A top-level key whose value is not a mapping is kept under its
    own name. If the same key is defined more than once, the last definition
    wins and a ``UserWarning`` is emitted so the collision is not silent.
    An empty YAML file yields an empty namespace.

    Args:
        config_path: path to the YAML file.

    Returns:
        types.SimpleNamespace holding the flattened keys.
    """
    import warnings  # local import: only needed for the duplicate-key warning

    with open(config_path, 'r') as f:
        # safe_load returns None for an empty document; treat that as "no keys".
        full_config = yaml.safe_load(f) or {}

    flattened_config = {}

    def _set(key, value):
        # Later definitions overwrite earlier ones; make that visible.
        if key in flattened_config:
            warnings.warn(
                "config key {!r} is defined more than once; using the last value".format(key),
                stacklevel=3,
            )
        flattened_config[key] = value

    for top_level_key, sub_dict in full_config.items():
        if isinstance(sub_dict, dict):
            # A section: hoist each of its keys to the top level.
            for k, v in sub_dict.items():
                _set(k, v)
        else:
            # Non-mapping top-level value: keep it under its own key.
            _set(top_level_key, sub_dict)

    return SimpleNamespace(**flattened_config)

# Load the demo config into a flat namespace; later cells read attributes such
# as cfg.model_path and cfg.wav2vec2model_path from it. (No `global` statement
# is needed: this assignment already runs at module scope.)
cfg = load_and_flatten_yaml("config/talkinghead-1kh/demo.yaml")

In [None]:
# Build the model described by `cfg` and move it to the notebook's device.
model = get_model(cfg)
model = model.to(device)

# Load the trained weights. Tensors are mapped to CPU first so a checkpoint saved
# on a different GPU layout still loads; load_state_dict then copies them into
# the model. The checkpoint is read from disk exactly once.
if os.path.isfile(cfg.model_path):
    print("=> loading checkpoint '{}'".format(cfg.model_path))
    checkpoint = torch.load(cfg.model_path, map_location="cpu")
    # strict=False is forwarded to the project's load_state_dict helper;
    # presumably it tolerates missing/unexpected keys — verify in base.baseTrainer.
    load_state_dict(model, checkpoint['state_dict'], strict=False)
    print("=> loaded checkpoint '{}'".format(cfg.model_path))
else:
    raise RuntimeError("=> no checkpoint found at '{}'".format(cfg.model_path))

model.eval()

# Input .wav files are read from audio_dir; predicted blendshapes go to encoded_dir.
audio_dir   = "demo/input"
encoded_dir = "demo/output"

In [None]:
# Load a reference clip of tracked FLAME parameters; the model is conditioned on
# it via `target_style` when predicting from audio. The default path can be
# overridden with the STYLE_SAMPLE_PATH environment variable.
style_sample_path = os.environ.get(
    "STYLE_SAMPLE_PATH",
    "/root/Datasets/joint_data/npz/oonc4u-Adbc_0001_S366_E617_L271_T84_R735_B548.npz",
)

# allow_pickle=True can execute code embedded in object arrays: only load
# trusted files. The context manager releases the archive's file handle.
with np.load(style_sample_path, allow_pickle=True) as flame_param:
    # Three on-disk layouts are supported, told apart by which keys exist.
    if 'pose' in flame_param:
        # 'pose' holds global rotation in columns 0:3 and jaw in columns 3:6.
        expr  = flame_param["exp"].reshape(-1, 50)
        jaw   = flame_param["pose"][:, 3:6].reshape(-1, 3)
        gpose = flame_param["pose"][:, 0:3].reshape(-1, 3)
        gpose = gpose - gpose.mean(axis=0, keepdims=True)  # center over frames
    elif 'pose_params' in flame_param:
        expr  = flame_param['expression_params'].reshape(-1, 50)
        jaw   = flame_param["jaw_params"].reshape(-1, 3)
        gpose = flame_param["pose_params"].reshape(-1, 3)
        gpose = gpose - gpose.mean(axis=0, keepdims=True)  # center over frames
    else:
        expr  = flame_param["exp"].reshape((flame_param["exp"].shape[0], -1))
        gpose = flame_param["gpose"].reshape((flame_param["gpose"].shape[0], -1))
        jaw   = flame_param["jaw"].reshape((flame_param["jaw"].shape[0], -1))

# Savitzky-Golay smoothing along time (axis 0) removes the tracker's flicker from
# the head rotation. savgol_filter needs at least window_length (7) frames, so
# shorter clips are used unfiltered instead of raising.
if gpose.shape[0] >= 7:
    gpose = savgol_filter(gpose, window_length=7, polyorder=2, axis=0)

# Eyelid channels are filled with ones (not read from the file).
eyelids = np.ones((expr.shape[0], 2))
# Channel order expr | gpose | jaw | eyelids — the same order the prediction
# cell uses when splitting model output with [50, 3, 3, 2].
concat_blendshapes = np.concatenate((expr, gpose, jaw, eyelids), axis=1)
# Batch of one, placed on the notebook's device (honours the CPU fallback).
style_tensor = torch.as_tensor(concat_blendshapes, dtype=torch.float32, device=device).unsqueeze(0)

In [None]:
# Switch to eval mode and predict blendshapes from every .wav in `audio_dir`,
# conditioned on `style_tensor`; results are written as .npz to `encoded_dir`.
model.eval()
os.makedirs(encoded_dir, exist_ok=True)  # np.savez does not create directories

# The feature extractor does not depend on the clip, so load it once.
processor = Wav2Vec2FeatureExtractor.from_pretrained(cfg.wav2vec2model_path)
print(cfg.wav2vec2model_path)

for wav_file in sorted(glob.glob(os.path.join(audio_dir, "*.wav"))):
    print('Generating facial animation for {}...'.format(wav_file))
    # splitext strips only the extension ("take.v2.wav" -> "take.v2"), so the
    # .npz name still matches the .wav that the video step looks up.
    test_name = os.path.splitext(os.path.basename(wav_file))[0]
    predicted_blendshapes_path = os.path.join(encoded_dir, test_name + '.npz')

    speech_array, _ = librosa.load(wav_file, sr=16000)

    # Wav2Vec2 input values for the whole clip, reshaped into a single row (batch of one).
    audio_feature = np.squeeze(processor(speech_array, sampling_rate=16000).input_values)
    audio_feature = np.reshape(audio_feature, (-1, audio_feature.shape[0]))
    audio_feature = torch.FloatTensor(audio_feature).to(device)

    with torch.no_grad():
        blendshapes_out = model.predict(audio_feature, target_style=style_tensor)

        # Same channel layout as the style tensor: 50 exp, 3 gpose, 3 jaw, 2 eyelids.
        exp_out, gpose_out, jaw_out, eyelids_out = torch.split(blendshapes_out, [50, 3, 3, 2], dim=-1)
        exp_out, gpose_out, jaw_out, eyelids_out = (
            exp_out.squeeze(1), gpose_out.squeeze(1), jaw_out.squeeze(1), eyelids_out.squeeze(1)
        )

        # Filter head-rotation jitter along time. savgol_filter raises when there
        # are fewer frames than the window (7), so very short clips stay unfiltered.
        gpose_np = gpose_out.squeeze(0).cpu().numpy()
        if gpose_np.shape[0] >= 7:
            gpose_np = savgol_filter(gpose_np, window_length=7, polyorder=2, axis=0)
        gpose_out = torch.FloatTensor(gpose_np).unsqueeze(0)

        print(f'Exp shape: {exp_out.shape}, Gpose shape: {gpose_out.shape}, Jaw shape: {jaw_out.shape}, Eyelids shape: {eyelids_out.shape}')

        np.savez(
            predicted_blendshapes_path,
            exp=exp_out.detach().cpu().numpy(),
            gpose=gpose_out.detach().cpu().numpy(),
            jaw=jaw_out.detach().cpu().numpy(),
            eyelids=eyelids_out.detach().cpu().numpy(),
        )
        print(f'Save facial animation in {predicted_blendshapes_path}')

In [None]:
def get_vertices_from_blendshapes(expr, gpose, jaw, eyelids, flame_model=None, target_device=None):
    """Evaluate FLAME on per-frame blendshape tensors and return the mesh vertices.

    The identity shape is neutral (all-zero shape coefficients) and both eyes
    use a zero eye pose (the identity rotation).

    Args:
        expr: expression coefficients, one row per frame (callers in this
            notebook pass shape (T, 50)).
        gpose: global head rotation per frame, (T, 3); placed before `jaw` in
            the FLAME pose vector.
        jaw: jaw pose per frame, (T, 3).
        eyelids: per-frame eyelid values. NOTE(review): accepted for interface
            compatibility but not passed to FLAME, so it has no effect here.
        flame_model: FLAME module to evaluate; defaults to the notebook-level `flame`.
        target_device: device for all tensors; defaults to the notebook-level `device`.

    Returns:
        Detached tensor: the first output of ``flame_model.forward`` (the vertices;
        exact shape defined by FLAMEModel).
    """
    flame_model = flame if flame_model is None else flame_model
    target_device = device if target_device is None else target_device

    expr_tensor = expr.to(target_device)
    gpose_tensor = gpose.to(target_device)
    jaw_tensor = jaw.to(target_device)
    num_frames = expr_tensor.shape[0]

    # Inference only: no autograd graph is needed for the FLAME evaluation.
    with torch.no_grad():
        # 300 matches FLAMEModel(n_shape=300) used in this notebook.
        shape_params = torch.zeros(num_frames, 300, device=target_device)
        # Zero rotation for both eyes: 3 parameters per eye, 6 per frame.
        eye_pose = torch.zeros(num_frames, 6, device=target_device)
        pose = torch.cat([gpose_tensor, jaw_tensor], dim=-1)

        vertices, _ = flame_model.forward(shape_params=shape_params,
                                          expression_params=expr_tensor,
                                          pose_params=pose,
                                          eye_pose_params=eye_pose)
    return vertices.detach()

In [None]:
def update(frame_inx, renderer_output_blendshapes, axes):
    """Draw rendered frame ``frame_inx`` on ``axes`` (FuncAnimation callback).

    Args:
        frame_inx: zero-based frame index supplied by FuncAnimation.
        renderer_output_blendshapes: renderer output dict; its ``'rendered_img'``
            entry is indexed per frame, and each frame is transposed from
            channel-first (C, H, W) to (H, W, C) for display.
        axes: matplotlib Axes to draw into; cleared on every call.
    """
    frame = renderer_output_blendshapes['rendered_img'][frame_inx].detach().cpu().numpy().transpose(1, 2, 0)

    # Clip before scaling: float values outside [0, 1] would overflow/wrap when
    # cast to uint8 and show up as speckle noise.
    frame_u8 = (np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8)

    axes.clear()
    axes.imshow(frame_u8)
    axes.axis('off')
    axes.set_title(f'Frame Stage 1 (Blendshape) {frame_inx + 1}')

# Helpers and entry point for turning one predicted .npz into an MP4.
def _load_encoded_blendshapes(npz_path):
    """Load (exp, gpose, jaw, eyelids) from ``npz_path`` as float32 tensors on ``device``.

    Arrays are reshaped to (T, 50), (T, 3), (T, 3) and (T, 2); the archive is
    opened once and closed before returning.
    """
    widths = {'exp': 50, 'gpose': 3, 'jaw': 3, 'eyelids': 2}
    with np.load(npz_path) as data:
        return tuple(
            torch.tensor(data[key].reshape(-1, width), dtype=torch.float32).to(device)
            for key, width in widths.items()
        )


def _mux_audio(video_file, audio_file, output_file):
    """Copy the video stream of ``video_file`` and add ``audio_file`` as AAC; True on success.

    Arguments are passed to ffmpeg as a list (no shell), so paths containing
    spaces or shell metacharacters are handled safely.
    """
    cmd = ['ffmpeg', '-y', '-i', video_file, '-i', audio_file,
           '-c:v', 'copy', '-c:a', 'aac', '-strict', 'experimental', output_file]
    return subprocess.run(cmd).returncode == 0


def create_and_save_video(encoded_dir, file_name, renderer, audio_dir, output_dir):
    """Render one predicted blendshape file to video and mux in its audio if present.

    Args:
        encoded_dir: directory containing ``<base_name>.npz`` predictions.
        file_name: name (or path) of the .npz file; its basename without the
            ``.npz`` suffix is used as ``base_name``.
        renderer: renderer module called as ``renderer.forward(vertices, cam)``.
        audio_dir: directory searched for ``<base_name>.wav``.
        output_dir: directory receiving ``<base_name>.mp4`` and, when the audio
            exists and ffmpeg succeeds, ``<base_name>_new_with_audio.mp4``.
    """
    base_name = os.path.basename(file_name)
    if base_name.endswith('.npz'):
        base_name = base_name[:-len('.npz')]
    print(base_name)

    exp, gpose, jaw, eyelids = _load_encoded_blendshapes(f'{encoded_dir}/{base_name}.npz')

    # Inference only: skip autograd bookkeeping for FLAME and the renderer.
    with torch.no_grad():
        vertices = get_vertices_from_blendshapes(exp, gpose, jaw, eyelids)
        print(vertices.shape)

        # Fixed camera [5, 0, 0] repeated for every frame (parameter layout is
        # defined by Renderer; presumably scale + 2D translation — verify).
        cam = torch.tensor([5, 0, 0], dtype=torch.float32).unsqueeze(0).to(device)
        cam = cam.expand(vertices.shape[0], -1)

        renderer_output_blendshapes = renderer.forward(vertices, cam)

    num_frames = renderer_output_blendshapes['rendered_img'].shape[0]

    video_file = f'{output_dir}/{base_name}.mp4'
    fig, axes = plt.subplots(1, 1, figsize=(5, 5))
    try:
        ani = animation.FuncAnimation(
            fig,
            update,
            frames=num_frames,
            fargs=(renderer_output_blendshapes, axes),
            interval=100,
        )
        ani.save(video_file, writer='ffmpeg', fps=25)
    finally:
        # Close the figure so rendering many clips does not accumulate open
        # pyplot figures (and their memory).
        plt.close(fig)
    print(f"Video saved as {video_file}")

    # =============== Add audio to the video ===============
    audio_file = f'{audio_dir}/{base_name}.wav'
    output_with_audio = f'{output_dir}/{base_name}_new_with_audio.mp4'
    if not os.path.exists(audio_file):
        print(f"Audio file {audio_file} not found")
        return
    if _mux_audio(video_file, audio_file, output_with_audio):
        print(f"Video with audio saved as {output_with_audio}")
    else:
        print(f"ffmpeg failed to add audio to {video_file}")

In [None]:
# Directory containing encoded files
encoded_dir = 'demo/output'
audio_dir   = 'demo/input'
output_dir  = 'demo/video'

# Check if the directory exists, if not, create it
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
    print(f"Directory created: {output_dir}")
else:
    print(f"Directory already exists: {output_dir}")

counter =  10
# Iterate over all files in the encoded directory
for file_name in os.listdir(encoded_dir):
    if counter == 0:
        break
    if file_name.endswith('.npz'):
        create_and_save_video(encoded_dir,file_name,renderer,audio_dir,output_dir)
    counter += 1