In [None]:
#!/usr/bin/env python
import sys
import os

# Change the working directory to the parent of the directory the notebook was
# started in (the project base folder), so relative paths used later such as
# 'demo/video' — and presumably the project-local imports — resolve from there.
# The starting directory is remembered in a kernel global, so re-running this
# cell lands in the same place instead of climbing one more level each time.
_NOTEBOOK_DIR = globals().get('_NOTEBOOK_DIR', os.getcwd())
os.chdir(os.path.dirname(_NOTEBOOK_DIR))

In [None]:
import torch
import cv2
import numpy as np
from flame_model.FLAME import FLAMEModel
from renderer.renderer import Renderer
import argparse
import torch.nn.functional as F
from pytorch3d.transforms import matrix_to_euler_angles
import subprocess
import tempfile 
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import librosa
import soundfile as sf
import torchaudio
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import wandb
import glob
from models.stage2 import CodeTalker
import yaml
from models import get_model
from base.baseTrainer import load_state_dict
from types import SimpleNamespace
from transformers import AutoProcessor, Wav2Vec2Processor, Wav2Vec2FeatureExtractor
import pickle
import itertools
from pytorch3d.renderer import look_at_view_transform
import random
from scipy.signal import savgol_filter


In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# FLAME head model configured with 300 shape and 50 expression coefficients.
flame = FLAMEModel(n_shape=300, n_exp=50).to(device)

# Single renderer instance reused for every video rendered in this notebook.
renderer = Renderer(render_full_head=True).to(device)

In [None]:
def get_vertices_from_blendshapes(expr, gpose, jaw, eyelids=None):
    """Run the global FLAME model on per-frame coefficients and return its vertices.

    The 300 shape coefficients are fixed to zero and both eye poses to the zero
    rotation, so only expression, global pose and jaw pose vary per frame.

    Args:
        expr: expression coefficients, one row per frame
            (assumed (N, 50) to match ``n_exp=50`` — TODO confirm).
        gpose: global pose per frame; concatenated before ``jaw`` on the last dim.
        jaw: jaw pose per frame.
        eyelids: accepted for API compatibility but currently IGNORED.
            NOTE(review): it was never forwarded to FLAME in the original
            either; confirm whether ``FLAMEModel.forward`` supports eyelids.

    Returns:
        The first output of ``flame.forward`` (the vertices), detached.
    """
    n_frames = expr.shape[0]

    # The result is detached anyway, so skip building an autograd graph.
    with torch.no_grad():
        expr_tensor = expr.to(device)
        pose = torch.cat([gpose.to(device), jaw.to(device)], dim=-1)

        # Neutral identity: all shape coefficients zero.
        shape_tensor = torch.zeros(n_frames, 300, device=device)
        # Right + left eye rotations (3 each), all zero — identical to the
        # Euler angles of the identity matrix used previously.
        eyes = torch.zeros(n_frames, 6, device=device)

        vertices, _ = flame.forward(shape_params=shape_tensor,
                                    expression_params=expr_tensor,
                                    pose_params=pose,
                                    eye_pose_params=eyes)
    return vertices.detach()

In [None]:
def update(frame_inx, renderer_output_blendshapes, axes):
    """Draw one rendered frame onto ``axes``; used as the FuncAnimation callback.

    Args:
        frame_inx: index of the frame to draw.
        renderer_output_blendshapes: tensor indexed by frame; each frame is
            channel-first (C, H, W) and is assumed to hold values in [0, 1].
        axes: axes object that is cleared and redrawn on every call.
    """
    # (C, H, W) tensor -> (H, W, C) numpy image for imshow.
    frame = renderer_output_blendshapes[frame_inx].detach().cpu().numpy().transpose(1, 2, 0)
    # Clip before the uint8 cast: values slightly outside [0, 1] (shading or
    # antialiasing overshoot) would otherwise wrap around / be undefined.
    frame_u8 = (np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8)

    axes.clear()
    axes.imshow(frame_u8)
    axes.set_position([0, 0, 1, 1])  # image fills the whole figure, no margins
    axes.axis('off')

def _load_motion_params(npz_path):
    """Load (expr, gpose, jaw) NumPy arrays from one encoded ``.npz`` file."""
    # ``with`` closes the zip file handle that NpzFile otherwise keeps open.
    with np.load(npz_path) as flame_param:
        for key in flame_param.keys():
            print(key, flame_param[key].shape)

        if 'pose' in flame_param:
            # Packed pose: columns 0:3 are used as global pose, 3:6 as jaw.
            expr  = flame_param['exp'].reshape(-1, 50)
            jaw   = flame_param["pose"][:, 3:6].reshape(-1, 3)
            gpose = flame_param["pose"][:, 0:3].reshape(-1, 3)
            gpose = gpose - gpose.mean(axis=0, keepdims=True)
        elif 'pose_params' in flame_param:
            expr  = flame_param['expression_params'].reshape(-1, 50)
            jaw   = flame_param["jaw_params"].reshape(-1, 3)
            gpose = flame_param["pose_params"].reshape(-1, 3)
            gpose = gpose - gpose.mean(axis=0, keepdims=True)
        else:
            # NOTE(review): unlike the layouts above, gpose is not mean-centred
            # here — confirm this is intentional.
            expr  = flame_param['exp'].reshape(-1, 50)
            gpose = flame_param["gpose"].reshape(-1, 3)
            jaw   = flame_param['jaw'].reshape(-1, 3)
    return expr, gpose, jaw


def _smooth_gpose(gpose, window_length=7, polyorder=2):
    """Savitzky-Golay smooth along time, shrinking the window for short clips."""
    n_frames = gpose.shape[0]
    # savgol_filter (mode='interp') needs window_length <= n_frames and
    # > polyorder; use the largest odd window that fits, or skip smoothing.
    largest_odd = n_frames if n_frames % 2 == 1 else n_frames - 1
    window = min(window_length, largest_odd)
    if window <= polyorder:
        return gpose
    return savgol_filter(gpose, window_length=window, polyorder=polyorder, axis=0)


def _mux_audio(video_file, audio_file, output_with_audio):
    """Copy the video stream, encode ``audio_file`` as AAC; return True on success."""
    # Argument list without a shell: paths with spaces/metacharacters are safe.
    cmd = ['ffmpeg', '-y', '-i', video_file, '-i', audio_file,
           '-c:v', 'copy', '-c:a', 'aac', '-strict', 'experimental',
           output_with_audio]
    return subprocess.run(cmd).returncode == 0


# Function to create and save the video
def create_and_save_video(encoded_dir, file_name, renderer, audio_dir, output_dir):
    """Render one encoded sequence to ``<output_dir>/<base>.mp4``, plus audio if found.

    Args:
        encoded_dir: directory that contains ``<base>.npz``.
        file_name: file name or path; its basename minus a trailing ``.npz``
            is ``<base>``.
        renderer: object whose ``forward(vertices, cam)`` returns a mapping
            with a ``'rendered_img'`` tensor indexed by frame.
        audio_dir: directory searched for ``<base>.wav``.
        output_dir: destination for ``<base>.mp4`` and
            ``<base>_with_audio.mp4``.
    """
    base_name = os.path.basename(file_name)
    if base_name.endswith('.npz'):
        base_name = base_name[:-len('.npz')]  # strip only the suffix
    print(base_name)

    expr, gpose, jaw = _load_motion_params(f'{encoded_dir}/{base_name}.npz')
    gpose = _smooth_gpose(gpose)

    print("expr ", expr.shape)
    print("gpose ", gpose.shape)
    print("jaw ", jaw.shape)

    expr_t  = torch.tensor(expr, dtype=torch.float32, device=device)
    gpose_t = torch.tensor(gpose, dtype=torch.float32, device=device)
    jaw_t   = torch.tensor(jaw, dtype=torch.float32, device=device)

    # Compute vertices from blendshapes (eyelids are not used).
    vertices = get_vertices_from_blendshapes(expr_t, gpose_t, jaw_t, None)
    print("vertices ", vertices.shape)

    # Fixed camera, identical for every frame.
    cam = torch.tensor([5, 0, 0], dtype=torch.float32, device=device).expand(vertices.shape[0], -1)
    print("cam ", cam.shape)

    # Rendering only; no gradients are needed.
    with torch.no_grad():
        frames = renderer.forward(vertices, cam)['rendered_img']
    n_frames = frames.shape[0]

    video_file = f'{output_dir}/{base_name}.mp4'
    fig, axes = plt.subplots(1, 1, figsize=(5, 5), tight_layout=False)
    try:
        ani = animation.FuncAnimation(fig, update, frames=n_frames,
                                      fargs=(frames, axes), interval=100)
        ani.save(video_file, writer='ffmpeg', fps=25)
    finally:
        # Close the figure so repeated calls do not leak figures / memory.
        plt.close(fig)
    print(f"Video saved as {video_file}")

    # =============== Add audio to the video ===============
    audio_file = f'{audio_dir}/{base_name}.wav'
    output_with_audio = f'{output_dir}/{base_name}_with_audio.mp4'
    if os.path.exists(audio_file):
        if _mux_audio(video_file, audio_file, output_with_audio):
            print(f"Video with audio saved as {output_with_audio}")
        else:
            print(f"ffmpeg failed to add audio to {video_file}")
    else:
        print(f"Audio file {audio_file} not found")


In [None]:
# Directory containing encoded files
# NOTE(review): machine-specific absolute paths — adjust for your setup.
encoded_dir = '/mnt/fasttalk/demo/output' #'/root/Datasets/ARTalk_data/converted_data/npz'
audio_dir   = '/mnt/fasttalk/demo/input'  #'/root/Datasets/ARTalk_data/converted_data/wav'
output_dir  = 'demo/video'

NUM_VIDEOS   = 20  # maximum number of .npz files to render
SHUFFLE_SEED = 0   # fixed seed -> same subset every run; None -> new subset each run

# Check if the directory exists, if not, create it
if os.path.exists(output_dir):
    print(f"Directory already exists: {output_dir}")
else:
    os.makedirs(output_dir)
    print(f"Directory created: {output_dir}")

# Filter to .npz first so only renderable files count toward NUM_VIDEOS
# (previously every directory entry consumed the budget). Sort before shuffling
# so the selection does not depend on os.listdir ordering.
npz_files = sorted(f for f in os.listdir(encoded_dir) if f.endswith('.npz'))
random.Random(SHUFFLE_SEED).shuffle(npz_files)

for file_name in npz_files[:NUM_VIDEOS]:
    print(f"Processing file: {file_name}")
    create_and_save_video(encoded_dir, file_name, renderer, audio_dir, output_dir)