### Compile
Import dependencies and initialize the argument, inference, and crop configs.

In [None]:
import time

import os
import contextlib
import os.path as osp
import numpy as np
import cv2
import torch
import yaml
import tyro
import subprocess
from rich.progress import track
import torchvision
import cv2
import threading
import queue
import torchvision.transforms as transforms
from concurrent.futures import ThreadPoolExecutor, as_completed
import glob
import os
import numpy as np
import time
import torch
import imageio

from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig
from src.config.crop_config import CropConfig

def partial_fields(target_class, kwargs):
    """Instantiate ``target_class`` from the subset of ``kwargs`` it knows about.

    A key is kept only when ``hasattr(target_class, key)`` is true, i.e. when
    the class itself (not an instance) exposes an attribute of that name.
    Every other entry of ``kwargs`` is silently dropped.
    """
    accepted = {}
    for name, value in kwargs.items():
        if hasattr(target_class, name):
            accepted[name] = value
    return target_class(**accepted)

# Build the top-level argument config without any overrides.
args = ArgumentConfig()
# Derive the inference and crop configs from the attributes of `args` whose
# names also exist on the respective config class (see `partial_fields`).
inference_cfg = partial_fields(InferenceConfig, args.__dict__)
crop_cfg = partial_fields(CropConfig, args.__dict__)
# print("inference_cfg: ", inference_cfg)
# print("crop_cfg: ", crop_cfg)
# Device string handed to the model loaders and tensor helpers in later cells.
device = 'cuda'
print("Compile complete")

### Initialize util functions

Import util functions

In [None]:
from src.utils.helper import load_model, concat_feat
from src.utils.camera import headpose_pred_to_degree, get_rotation_matrix
from src.utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
from src.config.inference_config import InferenceConfig
from src.utils.cropper import Cropper
from src.utils.camera import get_rotation_matrix
from src.utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream
from src.utils.crop import _transform_img, prepare_paste_back, paste_back
from src.utils.io import load_image_rgb, load_video, resize_to_limit, dump, load
from src.utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix, is_image
from src.utils.filter import smooth

Declare several models

In [None]:
# Read the model hyper-parameters. The file is opened in a `with` block so the
# handle is closed as soon as parsing finishes instead of being leaked.
with open(inference_cfg.models_config, 'r') as models_config_file:
    model_config = yaml.load(models_config_file, Loader=yaml.SafeLoader)
# init F
appearance_feature_extractor = load_model(inference_cfg.checkpoint_F, model_config, device, 'appearance_feature_extractor')
# init M
motion_extractor = load_model(inference_cfg.checkpoint_M, model_config, device, 'motion_extractor')
# init W
warping_module = load_model(inference_cfg.checkpoint_W, model_config, device, 'warping_module')
# init G
spade_generator = load_model(inference_cfg.checkpoint_G, model_config, device, 'spade_generator')
# init S and R (optional: only loaded when the checkpoint is configured and present on disk)
if inference_cfg.checkpoint_S is not None and os.path.exists(inference_cfg.checkpoint_S):
    stitching_retargeting_module = load_model(inference_cfg.checkpoint_S, model_config, device, 'stitching_retargeting_module')
else:
    stitching_retargeting_module = None

cropper = Cropper(crop_cfg=crop_cfg, device=device)

In [None]:
import numpy as np


def calculate_distance_ratio(lmk: np.ndarray, idx1: int, idx2: int, idx3: int, idx4: int, eps: float = 1e-6) -> np.ndarray:
    """Return ``|lmk[idx1] - lmk[idx2]| / (|lmk[idx3] - lmk[idx4]| + eps)`` per batch item.

    ``lmk`` is indexed as ``lmk[:, idx]``, so the first axis is the batch and
    the second axis selects a landmark. The norms are taken over axis 1 of the
    difference with ``keepdims=True``, giving one column per batch row.
    ``eps`` guards against a zero denominator.
    """
    numerator = np.linalg.norm(lmk[:, idx1] - lmk[:, idx2], axis=1, keepdims=True)
    denominator = np.linalg.norm(lmk[:, idx3] - lmk[:, idx4], axis=1, keepdims=True)
    return numerator / (denominator + eps)


def calc_eye_close_ratio(lmk: np.ndarray, target_eye_ratio: np.ndarray = None) -> np.ndarray:
    """Concatenate the left- and right-eye distance ratios along axis 1.

    The left ratio uses landmark indices (6, 18) over (0, 12); the right ratio
    uses (30, 42) over (24, 36). When ``target_eye_ratio`` is given it is
    appended as additional column(s) after the two ratios.
    """
    columns = [
        calculate_distance_ratio(lmk, 6, 18, 0, 12),
        calculate_distance_ratio(lmk, 30, 42, 24, 36),
    ]
    if target_eye_ratio is not None:
        columns.append(target_eye_ratio)
    return np.concatenate(columns, axis=1)


def calc_lip_close_ratio(lmk: np.ndarray) -> np.ndarray:
    """Return the lip distance ratio: landmarks (90, 102) over landmarks (48, 66)."""
    lip_ratio = calculate_distance_ratio(lmk, 90, 102, 48, 66)
    return lip_ratio

def calc_ratio(lmk_lst):
    """Compute the eye and lip ratios for every landmark set in ``lmk_lst``.

    Each landmark array gets a leading batch axis (``lmk[None]``) before being
    passed to the ratio helpers. Returns ``(eye_ratios, lip_ratios)``, two
    lists with one entry per input landmark set.
    """
    # Single pass so that one-shot iterables are handled like the plain loop.
    ratio_pairs = [
        (calc_eye_close_ratio(lmk[None]), calc_lip_close_ratio(lmk[None]))
        for lmk in lmk_lst
    ]
    eye_ratios = [eye for eye, _ in ratio_pairs]
    lip_ratios = [lip for _, lip in ratio_pairs]
    return eye_ratios, lip_ratios

def prepare_videos(imgs, device) -> torch.Tensor:
    """Convert driving frames into a normalized float tensor on ``device``.

    imgs: either a list of frames (stacked and given a trailing singleton
        axis, i.e. TxHxWx3 -> TxHxWx3x1) or an ndarray that is already 5-D in
        that same layout (the permute below requires five axes).
    return: float32 tensor in [0, 1], permuted to Tx1x3xHxW.
    Raises ValueError for any other input type.
    """
    if isinstance(imgs, list):
        stacked = np.expand_dims(np.array(imgs), axis=-1)  # TxHxWx3x1
    elif isinstance(imgs, np.ndarray):
        stacked = imgs
    else:
        raise ValueError(f'imgs type error: {type(imgs)}')

    normalized = np.clip(stacked.astype(np.float32) / 255., 0, 1)
    # TxHxWx3x1 -> Tx1x3xHxW
    return torch.from_numpy(normalized).permute(0, 4, 3, 1, 2).to(device)


Load a single video file or a directory of videos

In [None]:
def prepare_videos_(imgs, device):
    """Convert NxHxWx3 frames into an Nx3xHxW tensor scaled to [0, 1].

    imgs: list of HxWx3 frames or an NxHxWx3 ndarray.
    The raw array is moved to ``device`` first and only then divided by 255,
    so the normalization runs on the target device rather than in numpy.
    Raises ValueError for any other input type.
    """
    if isinstance(imgs, np.ndarray):
        batch = imgs
    elif isinstance(imgs, list):
        batch = np.array(imgs)
    else:
        raise ValueError(f'imgs type error: {type(imgs)}')

    # NxHxWx3 -> Nx3xHxW, then normalize on the device
    tensor = torch.from_numpy(batch).permute(0, 3, 1, 2).to(device)
    return torch.clamp(tensor / 255., 0, 1)

def read_video_frames(video_path):
    """Decode a video into a list of 256x256 RGB frames.

    Returns ``(video_path, frames)`` where ``frames`` is a list of HxWx3
    arrays as produced by ``cv2.resize``; the list is empty when the file
    cannot be read.

    The capture handle is released in a ``finally`` block so it is not leaked
    if colour conversion or resizing raises. When the backend reports a
    non-positive frame count (OpenCV returns 0 for unsupported properties),
    frames are read until the stream ends instead of returning nothing.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Honour the reported count when there is one; otherwise read to EOF.
        while frame_count <= 0 or len(frames) < frame_count:
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(cv2.resize(frame, (256, 256)))  # Resize to 256x256
    finally:
        cap.release()
    return video_path, frames

def read_multiple_videos(video_paths, num_threads=4):
    """Decode several videos concurrently with a thread pool.

    Returns a list of ``read_video_frames`` results in the same order as
    ``video_paths`` (``Executor.map`` preserves input order).
    """
    pool = ThreadPoolExecutor(max_workers=num_threads)
    with pool:
        decoded = pool.map(read_video_frames, video_paths)
        return list(decoded)

def process_videos(input_path, num_threads=4):
    """Load one video, or every ``*.mp4`` in a directory, into a single array.

    input_path: path to a video file or to a directory of ``.mp4`` files.
    num_threads: worker count used when a directory is given.
    return: ``(all_frames, video_lengths)`` where ``all_frames`` is the numpy
        array of all frames concatenated in sorted-path order and
        ``video_lengths`` lists the frame count of each video.
    """
    if not os.path.isdir(input_path):
        print(f"Processing single video file: {input_path}")
        decoded_videos = [read_video_frames(input_path)]
    else:
        # Sorted so the concatenation order is reproducible.
        pattern = os.path.join(input_path, '*.mp4')
        video_paths = sorted(glob.glob(pattern))
        print(f"Found {len(video_paths)} video files.")
        decoded_videos = read_multiple_videos(video_paths, num_threads)

    collected = []
    video_lengths = []
    for path, frames in decoded_videos:
        collected.extend(frames)
        video_lengths.append(len(frames))
        print(f"Processed video: {path}, frames: {len(frames)}")

    print(f"\nTotal frames across all videos: {sum(video_lengths)}")
    print(f"Video lengths: {video_lengths}")

    stacked = np.array(collected)
    print(f"Shape of concatenated array: {stacked.shape}")
    return stacked, video_lengths


In [None]:
import os
import subprocess

def get_audio_path(video_path):
    """Return the path where the audio extracted from ``video_path`` is stored.

    The file lives in the current working directory and is named
    ``extracted_audio_<stem>.wav``, where ``<stem>`` is the video's base name
    up to its first dot.
    """
    stem = os.path.basename(video_path).split('.')[0]
    return os.path.join(os.getcwd(), f"extracted_audio_{stem}.wav")


def extract_audio(video_path):
    """Extract the audio track of ``video_path`` with ffmpeg.

    The destination comes from ``get_audio_path`` so both functions always
    agree on the file name. Returns the audio path on success, or ``None``
    when ffmpeg exits with an error or the ``ffmpeg`` executable cannot be
    found (a message is printed in both cases).
    """
    audio_path = get_audio_path(video_path)
    command = ['ffmpeg', '-i', video_path, '-q:a', '0', '-map', 'a', audio_path]

    try:
        subprocess.run(command, check=True)
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        # FileNotFoundError: ffmpeg is not installed / not on PATH.
        print(f"Error extracting audio: {e}")
        return None

    print(f"Audio extracted successfully: {audio_path}")
    return audio_path

# Example usage:
# video_path = "path/to/your/video.mp4"
# audio_path = extract_audio(video_path)
# if audio_path:
#     print(f"Audio saved to: {audio_path}")
# else:
#     print("Failed to extract audio")


Motion Extractor

In [None]:
def get_kp_info(x: torch.Tensor, **kwargs) -> dict:
    """Run the motion extractor and return the implicit keypoint information.

    x: Bx3xHxW, normalized to 0~1
    flag_refine_info (keyword, default True): whether to pass pitch/yaw/roll
        through ``headpose_pred_to_degree`` and reshape 'kp' and 'exp' to BxNx3
    return: the dict produced by the motion extractor; per the original
        contract it holds 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'
    """
    use_half = inference_cfg.flag_use_half_precision
    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_half):
        kp_info = motion_extractor(x)

        if use_half:
            # cast every tensor entry back to float32, in place on the same dict
            kp_info.update({
                key: value.float()
                for key, value in kp_info.items()
                if isinstance(value, torch.Tensor)
            })

    if kwargs.get('flag_refine_info', True):
        batch_size = kp_info['kp'].shape[0]
        for angle in ('pitch', 'yaw', 'roll'):
            kp_info[angle] = headpose_pred_to_degree(kp_info[angle])[:, None]  # Bx1
        for name in ('kp', 'exp'):
            kp_info[name] = kp_info[name].reshape(batch_size, -1, 3)  # BxNx3

    return kp_info

def process_driving_video(I_d_lst):
    """Extract per-frame motion parameters from the driving frames.

    I_d_lst: indexable along its first axis; each ``I_d_lst[i]`` is passed to
        ``get_kp_info`` as one frame.
    return: template dict with 'n_frames', a fixed 'output_fps' of 25, a
        'motion' list of float32 numpy dicts ('scale', 'R', 'exp', 't'), the
        empty 'c_d_eyes_lst' / 'c_d_lip_lst' lists, and 'x_i_info_lst' with the
        raw keypoint info of every frame.
    """
    def _as_float32_numpy(tensor):
        return tensor.cpu().numpy().astype(np.float32)

    n_frames = I_d_lst.shape[0]
    motion_items = []
    kp_infos = []

    for frame_idx in range(n_frames):
        # collect s, R, δ and t for inference
        x_i_info = get_kp_info(I_d_lst[frame_idx])
        R_i = get_rotation_matrix(x_i_info['pitch'], x_i_info['yaw'], x_i_info['roll'])

        motion_items.append({
            'scale': _as_float32_numpy(x_i_info['scale']),
            'R': _as_float32_numpy(R_i),
            'exp': _as_float32_numpy(x_i_info['exp']),
            't': _as_float32_numpy(x_i_info['t']),
        })
        kp_infos.append(x_i_info)
        print(f'frame {frame_idx} done')

    # The eye / lip ratio lists are kept as empty placeholders.
    return {
        'n_frames': n_frames,
        'output_fps': 25,
        'motion': motion_items,
        'c_d_eyes_lst': [],
        'c_d_lip_lst': [],
        'x_i_info_lst': kp_infos,
    }

Source image extraction

In [None]:
def prepare_source(img: np.ndarray) -> torch.Tensor:
    """Convert a source image (or batch) into a normalized tensor.

    img: HxWx3 or BxHxWx3, uint8
    return: float32 tensor in [0, 1] with layout Bx3xHxW (B == 1 for a single
        image), placed on the module-level ``device``.
    Raises ValueError when ``img`` is neither 3-D nor 4-D.

    The input array is never modified: ``astype`` already returns a fresh
    array, so no separate copy is taken beforehand.
    """
    if img.ndim == 3:
        batch = img[np.newaxis]  # HxWx3 -> 1xHxWx3
    elif img.ndim == 4:
        batch = img  # BxHxWx3
    else:
        raise ValueError(f'img ndim should be 3 or 4: {img.ndim}')

    x = np.clip(batch.astype(np.float32) / 255., 0, 1)  # normalized and clipped to 0~1
    x = torch.from_numpy(x).permute(0, 3, 1, 2)  # BxHxWx3 -> Bx3xHxW
    return x.to(device)

def warp_decode(feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> dict:
    """Warp the appearance feature with the implicit keypoints and decode it.

    feature_3d: Bx32x16x64x64, feature volume
    kp_source: BxNx3
    kp_driving: BxNx3
    return: the dict returned by the warping module, with its 'out' entry
        replaced by the output of the SPADE generator for that entry.
    """
    # Line 18 of Algorithm 1: D(W(f_s; x_s, x'_d,i))
    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16,
                                 enabled=inference_cfg.flag_use_half_precision):
        # get decoder input
        ret_dct = warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
        # decode
        ret_dct['out'] = spade_generator(feature=ret_dct['out'])

    return ret_dct

def extract_feature_3d(x: torch.Tensor) -> torch.Tensor:
    """Run the appearance feature extractor F on ``x`` and return float32 features.

    x: Bx3xHxW, normalized to 0~1
    Inference runs without gradients, under CUDA autocast when half precision
    is enabled in ``inference_cfg``; the result is always cast back to float32.
    """
    half_enabled = inference_cfg.flag_use_half_precision
    autocast_ctx = torch.autocast(device_type='cuda', dtype=torch.float16, enabled=half_enabled)
    with torch.no_grad(), autocast_ctx:
        volume = appearance_feature_extractor(x)
    return volume.float()

def transform_keypoint(kp_info: dict):
    """Apply pose, expression, scale and translation to the canonical keypoints.

    kp_info: dict with 'kp' (BxNx3 or Bx(N*3)), 'pitch', 'yaw', 'roll', 't',
        'exp' and 'scale'.
    return: BxNx3 tensor computed as s * (x_c @ R + exp), with only the x and
        y components of t added afterwards.
    """
    kp = kp_info['kp']    # (bs, k, 3)
    pitch, yaw, roll = kp_info['pitch'], kp_info['yaw'], kp_info['roll']
    t, exp = kp_info['t'], kp_info['exp']
    scale = kp_info['scale']

    pitch, yaw, roll = (headpose_pred_to_degree(angle) for angle in (pitch, yaw, roll))

    bs = kp.shape[0]
    # a flat Bx(num_kp*3) layout and a Bxnum_kpx3 layout are both accepted
    num_kp = kp.shape[1] // 3 if kp.ndim == 2 else kp.shape[1]

    rot_mat = get_rotation_matrix(pitch, yaw, roll)    # (bs, 3, 3)

    # Eqn.2: s * (R * x_c,s + exp) + t
    canonical = kp.view(bs, num_kp, 3)
    expression = exp.view(bs, num_kp, 3)
    kp_transformed = canonical @ rot_mat + expression
    kp_transformed.mul_(scale[..., None])  # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3)
    kp_transformed[:, :, 0:2].add_(t[:, None, 0:2])  # tz is dropped, only tx ty are applied

    return kp_transformed

def parse_output(out: torch.Tensor) -> np.ndarray:
    """Convert a generator output tensor into uint8 images.

    out: Bx3xHxW tensor
    return: BxHxWx3 uint8 array; values are clipped to 0~1, scaled by 255 and
        clipped to 0~255 before the (truncating) cast.
    """
    channels_last = out.data.cpu().numpy().transpose(0, 2, 3, 1)  # Bx3xHxW -> BxHxWx3
    unit_range = channels_last.clip(0, 1)
    return (unit_range * 255).clip(0, 255).astype(np.uint8)

### Demo pipeline

In [None]:
# inputs
# NOTE(review): paths are hard-coded to one machine; edit them before running elsewhere.
input_vid_path = '/mnt/e/data/diffposetalk_data/TFHP_raw/crop/TH_00203/000.mp4'  # Can be a directory or a single video file
input_src_path = '/mnt/c/Users/mjh/Downloads/live_in/t3.jpg'
# Location where the audio extracted from the driving video is (or will be) stored.
# NOTE(review): this and has_audio_stream() below receive input_vid_path as-is,
# so the audio handling presumably expects a single file, not a directory - confirm.
input_audio_path = get_audio_path(input_vid_path)
# input_audio_path = '/mnt/c/Users/mjh/Downloads/live_in/i5.wav'

# Read video frames
all_frames, video_lengths = process_videos(input_vid_path)
driving_rgb_lst = all_frames
# Nx3xHxW tensor in [0, 1] on `device`
I_d_lst = prepare_videos_(driving_rgb_lst, device)
# Nx3xHxW -> Nx1x3xHxW, so that each I_d_lst[i] is a batch of one frame
I_d_lst = I_d_lst.unsqueeze(1)
# Keep only the frames of the first video, even when a directory was loaded.
I_d_lst = I_d_lst[0 : video_lengths[0]]
print(f"Shape of driving video: {I_d_lst.shape}")
# read audio if exists
# Reuse a previously extracted audio file when present; otherwise extract it
# from the video. extract_audio() returns None on failure, so audio_path can
# still be None after this branch.
audio_path = None
if has_audio_stream(input_vid_path) and not os.path.exists(input_audio_path):
    audio_path = extract_audio(input_vid_path)  # Extract audio from the video
elif os.path.exists(input_audio_path):
    audio_path = input_audio_path
else:
    raise ValueError("No audio stream found in the video and no audio file provided.")

# Extract motion information
template_dct = process_driving_video(I_d_lst)

# Load source image
img_rgb = load_image_rgb(input_src_path)
source_rgb_lst = [img_rgb]

# NOTE(review): source_lmk is not read again in the cells shown here.
source_lmk = cropper.calc_lmk_from_cropped_image(source_rgb_lst[0])
img_crop_256x256 = cv2.resize(source_rgb_lst[0], (256, 256))  # force to resize to 256x256

# extract the src implicit keypoint information
I_s = prepare_source(img_crop_256x256)  # 1x3x256x256 tensor
x_s_info = get_kp_info(I_s)  # keypoint info dict of the source image
x_c_s = x_s_info['kp']  # canonical keypoints, before pose/expression are applied
x_s = transform_keypoint(x_s_info)  # source keypoints after the transform
f_s = extract_feature_3d(I_s)  # appearance feature of the source image


#### Frontalize


In [None]:
# R = template_dct['motion'][0]['R']
# exp = template_dct['motion'][0]['exp']
# t = template_dct['motion'][0]['t']
# scale = template_dct['motion'][0]['scale']
# # print dims
# print(R.shape, exp.shape, t.shape, scale.shape)
# # print flatten dims
# print(R.flatten().shape, exp.flatten().shape, t.flatten().shape, scale.flatten().shape)
# # print range
# print(R.min(), R.max(), exp.min(), exp.max(), t.min(), t.max(), scale.min(), scale.max())

In [None]:
# for i in range(n_frames):
#     R = template_dct['motion'][i]['R']
#     exp = template_dct['motion'][i]['exp']
#     t = template_dct['motion'][i]['t']
#     scale = template_dct['motion'][i]['scale']
#     info = template_dct['x_i_info_lst'][i]
#     roll, pitch, yaw = info['roll'], info['pitch'], info['yaw']

#     new_R = get_rotation_matrix(pitch, yaw, roll)

In [None]:
import torch

def angular_distance(pose1, pose2, period=2 * torch.pi):
    """Return the norm of the wrapped element-wise angular difference.

    pose1, pose2: broadcastable tensors of angles.
    period: length of one full turn in the unit of the inputs. The default of
        2*pi is for radians; pass 360 for angles expressed in degrees.
    return: a scalar tensor, the 2-norm over all elements of the wrapped
        difference (not one value per row).

    The absolute difference is reduced modulo ``period`` before choosing the
    shorter way around, so differences larger than one full turn still yield a
    value in [0, period / 2] instead of a negative "distance".
    """
    diff = torch.remainder(torch.abs(pose1 - pose2), period)
    diff = torch.min(diff, period - diff)
    return torch.norm(diff)

def find_dominant_pose(poses):
    """Pick the pose that is closest, in aggregate, to all poses.

    poses: tensor whose first axis enumerates the candidate poses.
    return: ``(pose, index)`` for the candidate with the smallest score.

    The score of candidate ``i`` is ``angular_distance`` between that pose
    (broadcast) and the whole ``poses`` tensor. That call already reduces to a
    single norm, so the ``torch.sum`` around it leaves the value unchanged.
    """
    num_poses = poses.shape[0]
    scores = torch.zeros(num_poses, device=poses.device)
    for idx, candidate in enumerate(poses):
        scores[idx] = torch.sum(angular_distance(candidate.unsqueeze(0), poses))
    best = torch.argmin(scores)
    return poses[best], best

# Prepare data
n_frames = len(template_dct['motion'])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Collect all poses and t values, one row per frame
all_poses = torch.zeros(n_frames, 3, device=device)  # columns: roll, pitch, yaw
all_t = torch.zeros(n_frames, 3, device=device)

for i in range(n_frames):
    info = template_dct['x_i_info_lst'][i]
    roll, pitch, yaw = info['roll'], info['pitch'], info['yaw']
    all_poses[i] = torch.tensor([roll, pitch, yaw], device=device).squeeze()
    all_t[i] = torch.tensor(template_dct['motion'][i]['t'], device=device)

# Find dominant pose
# The stored angles come from headpose_pred_to_degree (degrees), while the
# dominant-pose search wraps differences at 2*pi. Search on a radian copy so
# the wrap-around is meaningful, then read the winning pose back in the
# original unit via its index.
_, dominant_index = find_dominant_pose(torch.deg2rad(all_poses))
dominant_pose = all_poses[dominant_index]

# Find median t
median_t = torch.median(all_t, dim=0).values

# Subtract dominant pose and median t from the sequence
for i in range(n_frames):
    # Update pose
    template_dct['x_i_info_lst'][i]['roll'] = (all_poses[i, 0]  - 1 * dominant_pose[0]).unsqueeze(0)
    template_dct['x_i_info_lst'][i]['pitch'] = (all_poses[i, 1] - 1 * dominant_pose[1]).unsqueeze(0)
    template_dct['x_i_info_lst'][i]['yaw'] = (all_poses[i, 2]   - 1 * dominant_pose[2]).unsqueeze(0)

    # Update t
    template_dct['motion'][i]['t'] = (all_t[i] - median_t).cpu().numpy()

    # Recalculate R with the updated pose
    new_R = get_rotation_matrix(
        template_dct['x_i_info_lst'][i]['pitch'],
        template_dct['x_i_info_lst'][i]['yaw'],
        template_dct['x_i_info_lst'][i]['roll']
    )
    template_dct['motion'][i]['R'] = new_R.cpu().numpy()

print(f"Dominant pose (roll, pitch, yaw): {dominant_pose.cpu().numpy()}")
print(f"Median t: {median_t.cpu().numpy()}")

Single frame retarget

In [None]:
# R = template_dct['motion'][0]['R']
# exp = template_dct['motion'][0]['exp']
# t = template_dct['motion'][0]['t']
# scale = template_dct['motion'][0]['scale']

# scale_tensor = torch.tensor(scale, device=device)
# R_tensor = torch.tensor(R, device=device)
# exp_tensor = torch.tensor(exp, device=device)
# t_tensor = torch.tensor(t, device=device)
# print(scale_tensor.shape, R_tensor.shape, exp_tensor.shape, t_tensor.shape)

# start = time.time()
# x_d_i_new = scale_tensor * (x_c_s @ R_tensor + exp_tensor) + t_tensor

# # x_d_i_new = scale * (x_c_s @ R + exp) + t
# out = warp_decode(f_s, x_s, x_d_i_new)
# # print(out)
# # I_p_i = parse_output(out['out'])[0]
# end_time = time.time() - start
# print(f'warp_decode time: {end_time}')

Large chunk of frames generator. Performance testing

In [None]:
# Notebook inspection: show the shape of the first frame's expression array.
template_dct['motion'][0]['exp'].shape

In [None]:
# Stack the expression arrays of all frames and flatten each frame to its 63
# values, giving a (num_frames, 63) tensor. np.stack + from_numpy replaces the
# former torch.tensor(list-of-arrays).T / permute(3, 2, 1, 0) round trip, which
# only undid itself and used the deprecated `.T` on a tensor with more than 2 dims.
exp_stacked = np.stack([frame['exp'] for frame in template_dct['motion']])
exp_reshaped = torch.from_numpy(exp_stacked).to(device).reshape(-1, 63)
print(exp_reshaped.shape)
# Average every feature over the frames (dim 0)
exp_avg = torch.mean(exp_reshaped, dim=0)

# Print the average for each feature
print("Average of each feature, shape: ", exp_avg.shape)
# tolist() transfers the 63 values once instead of calling .item() per element
for i, avg in enumerate(exp_avg.tolist()):
    print(f"Feature {i}: {avg:.4f}")


In [16]:
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Indices (out of the 63 flattened expression features) selected by the mask.
correct_indices = [
     4, 6, 7, 22, 33, 34, 40, 43, 45, 46, 48, 51, 52, 53, 57, 58, 59, 60, 61, 62
] # deleted 49,
_selected = set(correct_indices)
incorrect_indices = [i for i in range(63) if i not in _selected]

# A real boolean mask (the previous float tensor stored 1.0 / 0.0 and could not
# be used for boolean indexing).
bool_mask = torch.zeros(63, dtype=torch.bool, device=device)
bool_mask[correct_indices] = True
correct_indices = torch.tensor(correct_indices, device=device)
incorrect_indices = torch.tensor(incorrect_indices, device=device)


In [None]:
# Notebook inspection: number of entries in correct_indices.
len(correct_indices)

In [None]:
import cv2
import time
import os
import subprocess

def generate_frames(template_dct, x_c_s, f_s, x_s, device, show_cv=False, use_identity_pose=False):
    """Lazily render the driven frames, one per entry of ``template_dct['motion']``.

    template_dct: template dict whose 'motion' list holds per-frame 'R', 't'
        and 'exp' entries.
    x_c_s: canonical source keypoints.
    f_s: source appearance feature.
    x_s: transformed source keypoints.
    device: device on which the per-frame tensors are created.
    show_cv: when True, each frame is also shown in an OpenCV window; pressing
        'q' stops the generator.
    use_identity_pose: when True, the driving rotation/translation are
        replaced by a zero-angle rotation and a zero translation.
    yields: ``(frame_index, img_np)`` with ``img_np`` an HxWx3 uint8 RGB image.

    The OpenCV window is closed in a ``finally`` block, so it is also cleaned
    up when the consumer stops iterating early or an error occurs.
    """
    total_frames = len(template_dct['motion'])

    # Every frame uses this fixed scale; the driving scale is not applied.
    scale_tensor = torch.full((1,), 1.3, dtype=torch.float32, device=device)

    # Loop-invariant identity pose, built once and only when requested.
    if use_identity_pose:
        pitch_identity = torch.zeros((1), dtype=torch.float32, device=device)
        yaw_identity = torch.zeros((1), dtype=torch.float32, device=device)
        roll_identity = torch.zeros((1), dtype=torch.float32, device=device)
        identity_R = get_rotation_matrix(pitch_identity, yaw_identity, roll_identity)
        identity_t = torch.zeros((1, 3), dtype=torch.float32, device=device)

    try:
        if show_cv:
            cv2.namedWindow('Processed Frame', cv2.WINDOW_NORMAL)
            cv2.resizeWindow('Processed Frame', 512, 512)  # Adjust size as needed

        for frame_index, motion in enumerate(template_dct['motion']):
            if use_identity_pose:
                R_tensor, t_tensor = identity_R, identity_t
            else:
                # as_tensor accepts numpy arrays and tensors alike without the
                # copy-construct warning that torch.tensor(tensor) emits.
                R_tensor = torch.as_tensor(motion['R'], device=device)
                t_tensor = torch.as_tensor(motion['t'], device=device)

            # 21 keypoints x 3 coordinates (63 expression values per frame)
            exp_tensor = torch.as_tensor(motion['exp'], device=device).reshape(-1, 21, 3)

            # Process the frame: s * (x_c @ R + exp) + t
            x_d_i_new = scale_tensor * (x_c_s @ R_tensor + exp_tensor) + t_tensor
            out = warp_decode(f_s, x_s, x_d_i_new)

            # Clip to 0~1 before scaling so out-of-range values cannot wrap
            # around in the uint8 cast (same handling as parse_output).
            img_unit = np.clip(out['out'][0].permute(1, 2, 0).cpu().numpy(), 0, 1)
            img_np = (img_unit * 255).astype(np.uint8)

            if show_cv:
                # Convert from RGB to BGR for cv2
                img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
                # Display the frame
                cv2.imshow('Processed Frame', img_bgr)
                # Print progress
                print(f"Processed frame {frame_index+1}/{total_frames}")
                # Wait for a short time and check for 'q' key to quit
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break

            yield frame_index, img_np
    finally:
        if show_cv:
            cv2.destroyAllWindows()

def display_frames(frame_generator, display_option='opencv'):
    """Consume ``frame_generator`` and optionally show each frame with OpenCV.

    frame_generator: iterable of ``(frame_index, img_np)`` pairs with RGB images.
    display_option: 'opencv' shows the frames in a window (press 'q' to stop);
        any other value just drains the generator.

    The progress message reads the total frame count from the module-level
    ``template_dct``. The window is destroyed in a ``finally`` block so it
    does not stay open when rendering fails part-way through.
    """
    use_window = display_option == 'opencv'

    try:
        if use_window:
            cv2.namedWindow('Processed Frame', cv2.WINDOW_NORMAL)
            cv2.resizeWindow('Processed Frame', 512, 512)  # Adjust size as needed

        for frame_index, img_np in frame_generator:
            if use_window:
                # Convert from RGB to BGR for cv2
                img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)

                # Display the frame
                cv2.imshow('Processed Frame', img_bgr)

                # Print progress
                print(f"Processed frame {frame_index+1}/{len(template_dct['motion'])}")

                # Wait for a short time and check for 'q' key to quit
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break

            # Optional: add a small delay to make the display more visible
            time.sleep(0.01)  # Adjust as needed
    finally:
        if use_window:
            cv2.destroyAllWindows()

    print("Processing complete.")

def save_video(frame_generator, audio_path, output_video='video_driven_output.mp4', fps=25):
    """Write the generated frames to a video file and mux in the audio track.

    frame_generator: iterable of ``(frame_index, img_np)`` pairs with RGB
        uint8 frames; the first frame determines the video size.
    audio_path: path of the audio file to add.
    output_video: destination of the final video with audio.
    fps: frame rate of the written video.

    Raises FileNotFoundError when ``audio_path`` is None or does not exist,
    and ValueError when ``frame_generator`` yields no frames. A failing or
    missing ffmpeg is reported with a message; the silent intermediate file
    'video_driven_no_audio.mp4' is then left in place.
    """
    # Explicit check instead of `assert`: it also covers audio_path=None and
    # is not stripped when Python runs with -O.
    if audio_path is None or not os.path.exists(audio_path):
        raise FileNotFoundError(f"Audio file not found: {audio_path}")
    output_no_audio_path = 'video_driven_no_audio.mp4'

    # Remove the files if they exist
    for stale_path in (output_no_audio_path, output_video):
        if os.path.exists(stale_path):
            os.remove(stale_path)

    # Get the first frame to determine video dimensions
    frames = iter(frame_generator)
    try:
        _, first_frame = next(frames)
    except StopIteration:
        raise ValueError("frame_generator produced no frames") from None
    height, width = first_frame.shape[:2]

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video = cv2.VideoWriter(output_no_audio_path, fourcc, fps, (width, height))
    try:
        # Write the first frame, then the rest of the frames
        video.write(cv2.cvtColor(first_frame, cv2.COLOR_RGB2BGR))
        for _, frame in frames:
            video.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # Always finalize the container, even if frame generation fails.
        video.release()

    # Add audio to the video using ffmpeg
    ffmpeg_cmd = [
        'ffmpeg',
        '-i', output_no_audio_path,
        '-i', audio_path,
        '-c:v', 'copy',
        '-c:a', 'aac',
        '-shortest',
        output_video
    ]

    try:
        subprocess.run(ffmpeg_cmd, check=True)
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        # FileNotFoundError: the ffmpeg executable is not available.
        print(f"Error adding audio to video: {e}")
    else:
        os.remove(output_no_audio_path)
        print(f"Video with audio saved to {output_video}")

# Generate frames and save video
# generate_frames is a generator, so frames are rendered one by one while
# save_video consumes and writes them.
frame_generator = generate_frames(template_dct, x_c_s, f_s, x_s, device)
save_video(frame_generator, audio_path)  # audio_path is set by the audio detection step of the demo pipeline