In [None]:
#!/usr/bin/env python
import sys
import os

# Change to the parent directory (presumably the repository root) so the
# project-local imports below (flame, renderer, models, base) and the relative
# output paths used later resolve. Guarded so that re-running this cell, or
# starting from the root already, does not keep climbing the directory tree.
if not os.path.isdir('flame'):
    os.chdir('../')

import torch
import cv2
import numpy as np
from flame.flame import FlameHead
from renderer.renderer import Renderer
import argparse
import torch.nn.functional as F
from pytorch3d.transforms import matrix_to_euler_angles
import subprocess
import tempfile 
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import librosa
import soundfile as sf
import torchaudio
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import wandb
import glob
from models.stage2 import CodeTalker
import yaml
from models import get_model
from base.baseTrainer import load_state_dict
from types import SimpleNamespace
from transformers import AutoProcessor, Wav2Vec2Processor, Wav2Vec2FeatureExtractor
import pickle
import itertools

In [None]:
# Prefer the GPU when one is visible; the models below are moved onto this device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# FLAME head model: 300 shape coefficients, 50 expression coefficients.
flame = FlameHead(shape_params=300, expr_params=50).to(device)

# Mesh renderer (render_full_head=True — presumably the whole head rather than a face crop; TODO confirm).
renderer = Renderer(render_full_head=True).to(device)

In [None]:
def get_vertices_from_blendshapes(expr, jaw, neck=None, shape_dim=300):
    """Run the global FLAME model on per-frame expression / jaw (/ neck) parameters.

    Every other FLAME input is held fixed: a zero shape vector, zero global
    rotation, zero eye rotations and zero translation. Only the animated
    parameters move the mesh.

    Args:
        expr: Expression coefficients, one row per frame (the caller in this
            notebook passes a ``[T, 50]`` float32 tensor).
        jaw: Jaw pose, one row per frame (``[T, 3]`` from the caller).
        neck: Optional neck pose, one row per frame (``[T, 3]`` from the
            caller). When ``None``, a zero (identity) rotation is used for
            every frame.
        shape_dim: Number of FLAME shape coefficients. Must match the
            ``shape_params`` the global ``flame`` model was built with (300).

    Returns:
        The detached result of ``flame.forward(..., return_landmarks=False)``
        — presumably the per-frame vertex tensor; TODO confirm against FlameHead.
    """
    expr_tensor = expr.to(device)
    jaw_tensor = jaw.to(device)
    n_frames = expr_tensor.shape[0]

    def zeros(width):
        # A zero axis/Euler-angle vector is the identity rotation, and a zero
        # shape / translation vector is the neutral value for those inputs.
        return torch.zeros(n_frames, width, device=device)

    target_shape_tensor = zeros(shape_dim)
    rotation = zeros(3)
    eyes = zeros(6)  # right + left eye, 3 values each
    translation = zeros(3)
    neck_tensor = zeros(3) if neck is None else neck.to(device)

    # The output is detached anyway, so skip building an autograd graph.
    with torch.no_grad():
        flame_output_only_shape = flame.forward(
            target_shape_tensor,
            expr_tensor,
            rotation,
            neck_tensor,
            jaw_tensor,
            eyes,
            translation,
            return_landmarks=False,
        )

    return flame_output_only_shape.detach()

In [None]:
def update(frame_inx, renderer_output_blendshapes, axes):
    """Draw one rendered frame onto ``axes``; used as the ``FuncAnimation`` callback.

    Args:
        frame_inx: Zero-based frame index into
            ``renderer_output_blendshapes['rendered_img']``.
        renderer_output_blendshapes: Renderer output dict. Each entry of
            ``'rendered_img'`` is a channel-first ``[C, H, W]`` frame, which is
            transposed to ``[H, W, C]`` for ``imshow``.
        axes: Matplotlib axes that is cleared and redrawn on every call.
    """
    frame = renderer_output_blendshapes['rendered_img'][frame_inx].detach().cpu().numpy().transpose(1, 2, 0)

    # Assumes float intensities nominally in [0, 1]. Clip before scaling so
    # out-of-range values saturate rather than overflow the uint8 cast (which
    # NumPy leaves undefined and in practice wraps around).
    frame_u8 = (np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8)

    axes.clear()
    axes.imshow(frame_u8)
    axes.axis('off')
    axes.set_title(f'Frame Stage 1 (Blendshape) {frame_inx + 1}')

def lowpass_filter(tensor, kernel_size=10, pad_mode="constant"):
    """Moving-average (box) filter along time, applied independently to each column.

    Args:
        tensor: ``[seq_len, channels]`` tensor (the caller passes a ``[T, 3]``
            neck pose).
        kernel_size: Window length in frames, >= 1. Even sizes are supported;
            the window then extends one frame further forward than backward.
        pad_mode: ``torch.nn.functional.pad`` mode used at the sequence edges.
            The default ``"constant"`` pads with zeros, which pulls the first
            and last frames toward zero. ``"replicate"`` avoids that.

    Returns:
        Filtered tensor with the same shape, dtype and device as ``tensor``.

    Raises:
        ValueError: if ``tensor`` is not 2-D or ``kernel_size < 1``.
    """
    if tensor.dim() != 2:
        raise ValueError(f"expected a [seq_len, channels] tensor, got shape {tuple(tensor.shape)}")
    if kernel_size < 1:
        raise ValueError(f"kernel_size must be >= 1, got {kernel_size}")

    seq_len, channels = tensor.shape
    if seq_len == 0:
        return tensor.clone()

    x = tensor.transpose(0, 1).unsqueeze(0)  # -> [1, channels, seq_len]

    # Asymmetric padding keeps the output length equal to seq_len for odd AND
    # even kernels. For odd kernels this is the same symmetric k//2 padding as
    # before; for even ones, symmetric k//2 padding would return seq_len + 1.
    left = (kernel_size - 1) // 2
    right = kernel_size // 2
    x = F.pad(x, (left, right), mode=pad_mode)

    # One uniform kernel per channel (depthwise / grouped convolution).
    kernel = torch.full(
        (channels, 1, kernel_size),
        1.0 / kernel_size,
        device=tensor.device,
        dtype=tensor.dtype,
    )
    filtered = F.conv1d(x, kernel, groups=channels)

    return filtered.squeeze(0).transpose(0, 1)  # -> [seq_len, channels]

def _load_flame_params(npz_path):
    """Load per-frame ``(expr, jaw, neck)`` float32 tensors on ``device`` from one ``.npz`` file.

    ``exp`` is reshaped to ``[T, 50]``. ``pose[:, 3:6]`` (jaw) and
    ``pose[:, 0:3]`` (neck) are each reshaped to ``[T, 3]``.
    """
    # Open the archive once and close the file handle deterministically.
    with np.load(npz_path) as data:
        expr = data['exp'].reshape(-1, 50)
        pose = data['pose']
    jaw = pose[:, 3:6].reshape(-1, 3)
    neck = pose[:, 0:3].reshape(-1, 3)

    def to_device(array):
        return torch.tensor(array, dtype=torch.float32).to(device)

    return to_device(expr), to_device(jaw), to_device(neck)


def _prepare_neck(neck):
    """Swap neck components 1 and 2, smooth over time (k=9), and subtract the per-component mean.

    NOTE(review): this swap was labelled a horizontal flip "to avoid bias";
    exchanging two rotation components is not a mirror — confirm it is the
    intended transform.
    """
    swapped = neck[:, [0, 2, 1]]  # advanced indexing returns a copy
    smoothed = lowpass_filter(swapped, kernel_size=9)
    # Mean over the sequence (dim 0), one value per component.
    return smoothed - smoothed.mean(dim=0)


def _mux_audio(video_file, audio_file, output_file):
    """Copy the video stream, encode ``audio_file`` as AAC, and write ``output_file`` via ffmpeg.

    Returns True when ffmpeg exits with status 0.
    """
    # Argument list (no shell) so paths with spaces or shell metacharacters are
    # passed through literally.
    cmd = [
        'ffmpeg', '-y',
        '-i', video_file,
        '-i', audio_file,
        '-c:v', 'copy',
        '-c:a', 'aac', '-strict', 'experimental',
        output_file,
    ]
    result = subprocess.run(cmd)
    return result.returncode == 0


def create_and_save_video(encoded_dir, file_name, renderer, audio_dir, output_dir):
    """Render one encoded sequence to ``<output_dir>/<base>.mp4`` and, if
    ``<audio_dir>/<base>.wav`` exists, mux it into ``<output_dir>/<base>_with_audio.mp4``.

    Args:
        encoded_dir: Directory holding ``<base>.npz`` with ``exp`` and ``pose`` arrays.
        file_name: Name or path of the ``.npz`` file; only its basename is used.
        renderer: Renderer whose ``forward(vertices, cam)`` returns a dict
            containing ``'rendered_img'`` (one entry per frame).
        audio_dir: Directory searched for ``<base>.wav``.
        output_dir: Directory the videos are written to; expected to exist already.
    """
    base_name = os.path.basename(file_name)
    if base_name.endswith('.npz'):
        base_name = base_name[:-len('.npz')]
    print(base_name)

    expr, jaw, neck = _load_flame_params(f'{encoded_dir}/{base_name}.npz')
    neck = _prepare_neck(neck)

    print(expr.shape)
    print(jaw.shape)
    print(neck.shape)
    print("----")

    vertices = get_vertices_from_blendshapes(expr, jaw, neck)
    print(vertices.shape)

    # Same fixed camera parameters for every frame.
    cam_original = torch.tensor([10, 0, 0], dtype=torch.float32).expand(vertices.shape[0], -1).to(device)
    print(cam_original.shape)

    renderer_output_blendshapes = renderer.forward(vertices, cam_original)
    n_frames = renderer_output_blendshapes['rendered_img'].shape[0]

    video_file = f'{output_dir}/{base_name}.mp4'
    fig, axes = plt.subplots(1, 1, figsize=(5, 5))
    try:
        ani = animation.FuncAnimation(
            fig,
            update,
            frames=n_frames,
            fargs=(renderer_output_blendshapes, axes),
            interval=100,
        )
        ani.save(video_file, writer='ffmpeg', fps=25)
    finally:
        # Release the figure; otherwise every processed file leaks one.
        plt.close(fig)
    print(f"Video saved as {video_file}")

    # =============== Add audio to the video ===============
    audio_file = f'{audio_dir}/{base_name}.wav'
    output_with_audio = f'{output_dir}/{base_name}_with_audio.mp4'
    if not os.path.exists(audio_file):
        print(f"Audio file {audio_file} not found")
        return
    if _mux_audio(video_file, audio_file, output_with_audio):
        print(f"Video with audio saved as {output_with_audio}")
    else:
        print(f"ffmpeg failed to add audio {audio_file} to {video_file}")


In [None]:
# Input directories (encoded FLAME parameters and matching audio) and output directory.
encoded_dir = '/root/Datasets/ensemble_dataset/npz'
audio_dir   = '/root/Datasets/ensemble_dataset/wav'
output_dir  = 'demo/video'

# Maximum number of .npz files to render; set to None to process all of them.
max_files = 1

if os.path.isdir(output_dir):
    print(f"Directory already exists: {output_dir}")
else:
    os.makedirs(output_dir, exist_ok=True)
    print(f"Directory created: {output_dir}")

# Only .npz files count toward the limit. Sort them because os.listdir order
# is arbitrary, so the same files are picked on every run.
npz_files = sorted(name for name in os.listdir(encoded_dir) if name.endswith('.npz'))
if max_files is not None:
    npz_files = npz_files[:max_files]

for file_name in npz_files:
    print(f"Processing file: {file_name}")
    create_and_save_video(encoded_dir, file_name, renderer, audio_dir, output_dir)