In [None]:
# Directory scanned for the .npz archives that are listed and rendered further down.
data_path = "/mnt/Datasets/expressive_ft/npz"

In [None]:
import os
import glob
import numpy as np

# Collect every .npz archive under data_path (sorted for a stable order) and
# report which arrays each archive stores.
npz_files = glob.glob(os.path.join(data_path, "*.npz"))
npz_files.sort()
if npz_files:
    for npz_path in npz_files:
        # The context manager closes the archive as soon as its keys are read.
        with np.load(npz_path) as data:
            key_names = list(data.keys())
        print(f"{os.path.basename(npz_path)} keys: {key_names}")
else:
    print(f"No .npz files found in {data_path}")

In [None]:
# Load one archive for interactive inspection.
# The file is picked explicitly from npz_files instead of relying on whatever
# value the listing loop happened to leave in `npz_path` (which is undefined
# when no archive was found).
if not npz_files:
    raise FileNotFoundError(f"No .npz files found in {data_path}")
npz_path = npz_files[-1]  # the same file the listing loop visited last
# Copy the arrays into a plain dict so the archive handle can be closed right away;
# `tensor[key]` keeps working exactly as before.
with np.load(npz_path) as npz_archive:
    tensor = {key: npz_archive[key] for key in npz_archive.files}

In [None]:
# Display the shape of the 'jaw' array from the archive loaded above
# (raises KeyError if that archive has no 'jaw' entry).
tensor['jaw'].shape

In [None]:
import os
# Switch to the project root before the remaining imports; presumably this is what
# lets the project-local packages below (base, renderer, flame_model, models)
# resolve — TODO confirm.
os.chdir("/mnt/fasttalk")
import glob
import shutil
import subprocess
import numpy as np
import torch
import yaml
import librosa
from types import SimpleNamespace
from base.baseTrainer import load_state_dict
from transformers import Wav2Vec2FeatureExtractor

# Project-local rendering / model code plus pytorch3d rotation helpers.
from renderer.renderer import Renderer
from flame_model.FLAME import FLAMEModel
from pytorch3d.transforms import rotation_6d_to_matrix, matrix_to_euler_angles
from models import get_model

# ---- Tunable constants -------------------------------------------------------
POSE_IS_6D = True   # 6-wide pose vectors are decoded as 6D rotations (see to_euler_xyz)
CHUNK_SIZE = 50     # frames rendered per batch; reduce if still OOM
FPS = 24            # frame rate of every rendered video
EPS = 1e-8

# ffmpeg speed knobs
FFMPEG_PRESET = "fast"
FFMPEG_CRF = "18"

# ---- Input / output locations ------------------------------------------------
audio_dir = "/mnt/Datasets/expressive_ft/wav"
real_video_dir = "/mnt/Datasets/expressive_ft/synthetic_dataset"
output_dir = "/mnt/fasttalk/demo/video_expressive_ft"
os.makedirs(output_dir, exist_ok=True)

# ---- Device, renderer and FLAME model ----------------------------------------
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
renderer = Renderer(render_full_head=True).to(device)
flame = FLAMEModel(n_shape=300, n_exp=50).to(device)

def load_and_flatten_yaml(config_path):
    """Read a YAML config and collapse its first nesting level into one namespace.

    Top-level keys whose value is a mapping contribute that mapping's entries
    directly (the section name itself is dropped); any other top-level key is
    kept as-is. When the same key appears more than once, the entry processed
    last wins.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        A SimpleNamespace exposing the flattened keys as attributes.
    """
    with open(config_path, "r") as handle:
        sections = yaml.safe_load(handle)

    merged = {}
    for section_name, section_value in sections.items():
        if not isinstance(section_value, dict):
            merged[section_name] = section_value
            continue
        merged.update(section_value)
    return SimpleNamespace(**merged)

def load_model_for_eval(checkpoint_path, cfg):
    """Build the model described by `cfg`, load a checkpoint strictly and switch to eval mode.

    The checkpoint may either be a raw state dict or a container holding the
    weights under the "state_dict" key. Loading goes through the model's own
    `load_state_dict`.

    Args:
        checkpoint_path: File passed to `torch.load`.
        cfg: Configuration object handed to `get_model`.

    Returns:
        The model, moved to `device` and set to eval mode.
    """
    model = get_model(cfg).to(device)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    weights = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    model.load_state_dict(weights)
    model.eval()
    return model

def load_model_for_eval_state_dict(checkpoint_path, cfg):
    """Build the model for `cfg` and load a checkpoint non-strictly via the project helper.

    Unlike `load_model_for_eval`, the checkpoint is first mapped to CPU storage
    and the weights are applied with the project's `load_state_dict(...,
    strict=False)` helper.

    Args:
        checkpoint_path: File passed to `torch.load`.
        cfg: Configuration object handed to `get_model`.

    Returns:
        The model, moved to `device` and set to eval mode.
    """
    def _keep_on_cpu(storage, loc):
        # Map every storage to CPU regardless of where it was saved from.
        return storage.cpu()

    model = get_model(cfg).to(device)
    checkpoint = torch.load(checkpoint_path, map_location=_keep_on_cpu)
    is_wrapped = isinstance(checkpoint, dict) and "state_dict" in checkpoint
    weights = checkpoint["state_dict"] if is_wrapped else checkpoint
    load_state_dict(model, weights, strict=False)
    model.eval()
    return model

# ---- Stage 1 model (bound to `vq_model`) -------------------------------------
checkpoint_path_s1 = "/mnt/fasttalk/logs/talkinghead/talkinghead-s1/model_200/model.pth.tar"
cfg_s1 = load_and_flatten_yaml("/mnt/fasttalk/config/talkinghead-1kh/stage1.yaml")
cfg_s1.batch_size = 1
vq_model = load_model_for_eval(checkpoint_path_s1, cfg_s1)

# ---- Stage 2 model (bound to `s2_model`) -------------------------------------
base_checkpoint_path_s2 = "/mnt/fasttalk/logs/talkinghead/talkinghead-s2/model_260/model.pth.tar"
finetuned_checkpoint_path_s2 = "/mnt/fasttalk/logs/talkinghead/talkinghead-s2-finetunning/model_20/model.pth.tar"
cfg_s2 = load_and_flatten_yaml("/mnt/fasttalk/config/talkinghead-1kh/stage2finetunning.yaml")
cfg_s2.batch_size = 1

# Order matters: pretrained S2 weights go in first, the residual base is
# initialised from them (only when the config enables `use_residual`), and the
# finetuned weights are applied last. Both loads are non-strict.
s2_model = get_model(cfg_s2).to(device)

checkpoint_base = torch.load(base_checkpoint_path_s2, map_location=lambda storage, loc: storage.cpu())
base_is_wrapped = isinstance(checkpoint_base, dict) and "state_dict" in checkpoint_base
load_state_dict(
    s2_model,
    checkpoint_base["state_dict"] if base_is_wrapped else checkpoint_base,
    strict=False,
)
s2_model.eval()
if getattr(cfg_s2, "use_residual", False):
    s2_model.init_residual_base()

checkpoint_ft = torch.load(finetuned_checkpoint_path_s2, map_location=lambda storage, loc: storage.cpu())
ft_is_wrapped = isinstance(checkpoint_ft, dict) and "state_dict" in checkpoint_ft
load_state_dict(
    s2_model,
    checkpoint_ft["state_dict"] if ft_is_wrapped else checkpoint_ft,
    strict=False,
)
s2_model.eval()

# Audio feature extractor used by predict_from_audio.
wav2vec_processor = Wav2Vec2FeatureExtractor.from_pretrained(cfg_s2.wav2vec2model_path)

def to_euler_xyz(pose_tensor):
    """Convert a pose tensor to XYZ Euler angles, dispatching on its last dimension.

    * width 3: returned unchanged (treated as already being Euler angles).
    * width 6: decoded as a 6D rotation when POSE_IS_6D is set; otherwise the
      first three columns are returned.
    * width 9: reshaped to 3x3 rotation matrices and converted.

    Raises:
        ValueError: for any other last-dimension size.
    """
    width = pose_tensor.shape[-1]
    if width == 3:
        return pose_tensor
    if width == 6 and not POSE_IS_6D:
        return pose_tensor[:, :3]
    if width == 6:
        return matrix_to_euler_angles(rotation_6d_to_matrix(pose_tensor), "XYZ")
    if width == 9:
        return matrix_to_euler_angles(pose_tensor.view(-1, 3, 3), "XYZ")
    raise ValueError(f"Unsupported pose shape: {pose_tensor.shape}")

def get_vertices_from_blendshapes(expr_tensor, gpose_tensor, jaw_tensor, eyelids_tensor, shape_tensor):
    """Run the module-level FLAME model and return the detached vertices.

    Same mechanism as test_talkinghead-1kh_vq_bs.ipynb, now with shape.

    The pose handed to FLAME is the global pose concatenated with the jaw pose
    along the last dimension. Eye pose is the Euler form of the identity
    rotation, duplicated (once per eye, presumably) and repeated for every frame.

    NOTE(review): `eyelids_tensor` is accepted but never forwarded to FLAME —
    confirm whether that is intentional.

    Returns:
        The first output of `flame.forward`, detached from the autograd graph.
    """
    frame_count = expr_tensor.shape[0]

    identity_rot = torch.eye(3)[None].to(expr_tensor.device)
    neutral_eye = matrix_to_euler_angles(identity_rot, "XYZ").squeeze(0)
    eye_pose = torch.cat([neutral_eye, neutral_eye], dim=0).unsqueeze(0).expand(frame_count, -1)

    full_pose = torch.cat([gpose_tensor, jaw_tensor], dim=-1)
    vertices, _ = flame.forward(
        shape_params=shape_tensor,
        expression_params=expr_tensor,
        pose_params=full_pose,
        eye_pose_params=eye_pose,
    )
    return vertices.detach()

def _first_present(npz_data, keys):
    """Return the array stored under the first of `keys` present in `npz_data.files`, else None."""
    for key in keys:
        if key in npz_data.files:
            return npz_data[key]
    return None


def _fit_rows(tensor, n_rows):
    """Return a 2-D `tensor` with exactly `n_rows` rows.

    Surplus rows are dropped, a single row is broadcast, an empty tensor becomes
    zeros, and a short sequence is padded by repeating its last row.
    """
    rows = tensor.shape[0]
    if rows == n_rows:
        return tensor
    if rows > n_rows:
        return tensor[:n_rows]
    if rows == 0:
        return torch.zeros((n_rows, tensor.shape[-1]), device=tensor.device, dtype=tensor.dtype)
    if rows == 1:
        return tensor.expand(n_rows, -1)
    # 1 < rows < n_rows: expand() cannot grow a non-singleton dimension, so pad instead.
    tail = tensor[-1:].expand(n_rows - rows, -1)
    return torch.cat([tensor, tail], dim=0)


def extract_blendshapes_from_npz(npz_data):
    """Assemble per-frame blendshape and shape tensors from an opened .npz archive.

    Keys are looked up under several aliases: expression ("exp",
    "expression_params"), shape ("shape", "shape_params", "beta"), pose ("pose",
    "pose_params", "gpose") and jaw ("jaw", "jaw_params"). Eyelids are read from
    "eyelids". Missing shape/pose/jaw default to zeros and missing eyelids to
    ones (two columns).

    The frame count T is taken from the expression array. Shape, pose, jaw and
    eyelids are aligned to T rows (truncate / broadcast / repeat last row), so a
    per-sequence array with one row or a slightly different length no longer
    breaks the concatenation.

    Args:
        npz_data: Object exposing `.files` and item access like `numpy.load`'s
            NpzFile.

    Returns:
        None when no expression array is present. Otherwise a tuple
        `(blendshapes, shape_t)` on `device`, where `blendshapes` has the column
        layout [50 expression | 3 global pose | 3 jaw | eyelids] and `shape_t`
        is (T, 300), zero-padded or truncated to 300 columns.
    """
    exp = _first_present(npz_data, ("exp", "expression_params"))
    if exp is None:
        return None
    shape = _first_present(npz_data, ("shape", "shape_params", "beta"))
    pose = _first_present(npz_data, ("pose", "pose_params", "gpose"))
    jaw = _first_present(npz_data, ("jaw", "jaw_params"))

    exp_t = torch.from_numpy(exp.reshape(-1, 50)).float().to(device)
    T = exp_t.shape[0]

    # shape: one row per frame, exactly 300 coefficients
    if shape is None:
        shape_t = torch.zeros((T, 300), device=device)
    else:
        shape_arr = shape.reshape(-1, shape.shape[-1])
        shape_t = _fit_rows(torch.from_numpy(shape_arr).float().to(device), T)
        if shape_t.shape[-1] < 300:
            pad = torch.zeros((shape_t.shape[0], 300 - shape_t.shape[-1]), device=device)
            shape_t = torch.cat([shape_t, pad], dim=-1)
        elif shape_t.shape[-1] > 300:
            shape_t = shape_t[:, :300]

    # Always treat pose as global pose (convert 6D -> 3D Euler)
    if pose is not None:
        pose_t = torch.from_numpy(pose.reshape(-1, pose.shape[-1])).float().to(device)
    else:
        pose_t = torch.zeros((T, 3), device=device)
    gpose_t = _fit_rows(to_euler_xyz(pose_t), T)

    if jaw is None:
        jaw_t = torch.zeros((T, 3), device=device)
    else:
        jaw_t = torch.from_numpy(jaw.reshape(-1, jaw.shape[-1])).float().to(device)
        jaw_t = _fit_rows(to_euler_xyz(jaw_t), T)

    # eyelids
    if "eyelids" in npz_data.files:
        eyelids = npz_data["eyelids"]
        eyelids_t = torch.from_numpy(eyelids.reshape(-1, eyelids.shape[-1])).float().to(device)
        eyelids_t = _fit_rows(eyelids_t, T)
    else:
        eyelids_t = torch.ones((T, 2), device=device)

    blendshapes = torch.cat([exp_t, gpose_t, jaw_t, eyelids_t], dim=-1)
    return blendshapes, shape_t

def load_style_tensor(style_path):
    """Load a style tensor from a torch file, a .npy array or a .npz archive.

    The file is first tried with `torch.load`; a tensor or ndarray result is
    used directly. Anything else (including a failed torch load) falls back to
    `np.load`. For an .npz archive the blendshapes built by
    `extract_blendshapes_from_npz` are used, or the first stored array when the
    archive has no expression data.

    NOTE(security): both `torch.load` and `np.load(..., allow_pickle=True)` can
    unpickle arbitrary objects — only point this at trusted files.

    Args:
        style_path: Path of the style file.

    Returns:
        A float tensor on `device`; a 2-D result gets a leading batch dimension.
    """
    try:
        loaded = torch.load(style_path, map_location=device)
        if torch.is_tensor(loaded):
            style_t = loaded.float().to(device)
        elif isinstance(loaded, np.ndarray):
            style_t = torch.from_numpy(loaded).float().to(device)
        else:
            style_t = None
    except Exception:
        # Deliberate best-effort: not a torch-loadable file, try NumPy below.
        style_t = None
    if style_t is None:
        npz_data = np.load(style_path, allow_pickle=True)
        if isinstance(npz_data, np.lib.npyio.NpzFile):
            # NpzFile keeps the underlying zip open; close it once the arrays are copied out.
            try:
                extracted = extract_blendshapes_from_npz(npz_data)
                if extracted is None:
                    key = npz_data.files[0]
                    style_t = torch.from_numpy(npz_data[key]).float().to(device)
                else:
                    style_t = extracted[0]
            finally:
                npz_data.close()
        else:
            style_t = torch.from_numpy(npz_data).float().to(device)
    if style_t.dim() == 2:
        style_t = style_t.unsqueeze(0)
    return style_t

def find_audio_for_stem(stem):
    """Return the audio file in `audio_dir` whose base name is `stem`, or None.

    Known audio extensions are tried first, in order of preference. If none
    exists, the alphabetically first file named `<stem>.<anything>` is returned.

    The fallback pattern is built with `glob.escape`, so directory names and
    stems containing glob metacharacters (`[`, `]`, `*`, `?`) are matched
    literally.
    """
    for ext in (".wav", ".mp3", ".flac", ".m4a", ".aac", ".ogg"):
        candidate = os.path.join(audio_dir, stem + ext)
        if os.path.exists(candidate):
            return candidate
    pattern = os.path.join(glob.escape(audio_dir), glob.escape(stem) + ".*")
    matches = sorted(glob.glob(pattern))
    return matches[0] if matches else None

def find_real_video_for_stem(stem):
    """Return the reference video for `stem` under `real_video_dir`, or None.

    Lookup order:
      1. `<real_video_dir>/<stem><ext>` for the known video extensions.
      2. A recursive search for `<stem>.*`; matches with a known video extension
         are preferred, so a sidecar file (e.g. `<stem>.json`) that sorts
         earlier is no longer returned in place of the video.
      3. The alphabetically first recursive match of any extension.

    The directory and stem are escaped with `glob.escape` so glob
    metacharacters in them are matched literally.
    """
    video_exts = (".mp4", ".mov", ".mkv", ".avi", ".webm")
    # direct match in root
    for ext in video_exts:
        candidate = os.path.join(real_video_dir, stem + ext)
        if os.path.exists(candidate):
            return candidate
    # recursive search
    pattern = os.path.join(glob.escape(real_video_dir), "**", glob.escape(stem) + ".*")
    matches = sorted(glob.glob(pattern, recursive=True))
    for match in matches:
        if os.path.splitext(match)[1].lower() in video_exts:
            return match
    return matches[0] if matches else None

def _iter_rendered_frames(blendshapes, shape_t, chunk_size):
    """Render `blendshapes` chunk by chunk and yield `(frame_index, uint8_frame)` pairs.

    Each chunk is split by column into expression [:50], global pose [50:53],
    jaw [53:56] and eyelids [56:], turned into vertices and rendered with a
    fixed camera. GPU tensors of a chunk are released before its frames are
    yielded.
    """
    cam_base = torch.tensor([5, 0, 0], dtype=torch.float32).unsqueeze(0).to(device)
    total = blendshapes.shape[0]
    for start in range(0, total, chunk_size):
        end = min(total, start + chunk_size)
        chunk = blendshapes[start:end]
        with torch.no_grad():
            verts = get_vertices_from_blendshapes(
                chunk[:, :50], chunk[:, 50:53], chunk[:, 53:56], chunk[:, 56:], shape_t[start:end]
            )
            cam = cam_base.expand(verts.shape[0], -1)
            frames = renderer.forward(verts, cam)["rendered_img"]
        # Axis 1 is moved last and values are scaled by 255; presumably the renderer
        # returns (N, C, H, W) images in [0, 1] — TODO confirm.
        frames_np = (frames.detach().cpu().numpy().transpose(0, 2, 3, 1) * 255).clip(0, 255).astype(np.uint8)
        del verts, frames
        torch.cuda.empty_cache()
        for offset, frame in enumerate(frames_np):
            yield start + offset, frame


def _write_png_sequence(frame_iter, seq_dir, imageio):
    """Write frames from `frame_iter` to `seq_dir` as zero-padded `%06d.png` files."""
    os.makedirs(seq_dir, exist_ok=True)
    # Remove PNGs left by an earlier render of the same output: with a shorter
    # sequence the old trailing frames would otherwise be encoded into the video.
    for stale in glob.glob(os.path.join(glob.escape(seq_dir), "*.png")):
        os.remove(stale)
    for index, frame in frame_iter:
        imageio.imwrite(os.path.join(seq_dir, f"{index:06d}.png"), frame)


def render_sequence_from_blendshapes(blendshapes, shape_t, out_video_path, fps=FPS, chunk_size=CHUNK_SIZE):
    """Render a blendshape sequence to `out_video_path`.

    Output strategy, in order of preference:
      1. ffmpeg CLI on PATH: frames are written to `<out_video_path>_frames`
         and encoded with libx264 (matches test_talkinghead behavior).
      2. imageio's FFMPEG writer: frames are streamed straight into the video.
      3. Neither available: only the PNG sequence in `<out_video_path>_frames`
         is produced (still reported as success).

    Args:
        blendshapes: (T, D) tensor with the column layout consumed by
            `_iter_rendered_frames`.
        shape_t: Per-frame shape tensor, sliced with the same frame ranges.
        out_video_path: Destination video file.
        fps: Frame rate of the output.
        chunk_size: Number of frames rendered per batch.

    Returns:
        True on success; False when imageio is missing, there are no frames to
        render, or the ffmpeg encode fails.
    """
    try:
        import imageio.v2 as imageio
    except Exception as exc:
        print(f"imageio not available: {exc}")
        return False
    if blendshapes.shape[0] == 0:
        # Nothing to render; ffmpeg would only fail on an empty frame directory.
        print(f"No frames to render for {out_video_path}")
        return False

    seq_dir = out_video_path + "_frames"

    # Prefer ffmpeg CLI (matches test_talkinghead behavior)
    if shutil.which("ffmpeg"):
        _write_png_sequence(_iter_rendered_frames(blendshapes, shape_t, chunk_size), seq_dir, imageio)
        cmd = [
            "ffmpeg", "-y",
            "-framerate", str(fps),
            "-i", os.path.join(seq_dir, "%06d.png"),
            "-c:v", "libx264",
            "-preset", FFMPEG_PRESET,
            "-crf", FFMPEG_CRF,
            "-pix_fmt", "yuv420p",
            out_video_path
        ]
        try:
            subprocess.run(cmd, check=True)
            return True
        except Exception as exc:
            print(f"ffmpeg encode failed: {exc}")
            return False

    # Fallback to imageio FFMPEG writer if available
    try:
        writer = imageio.get_writer(out_video_path, fps=fps, codec="libx264", format="FFMPEG")
    except Exception as exc:
        print(f"FFMPEG writer not available ({exc}); falling back to PNG sequence.")
        _write_png_sequence(_iter_rendered_frames(blendshapes, shape_t, chunk_size), seq_dir, imageio)
        return True
    with writer:
        for _, frame in _iter_rendered_frames(blendshapes, shape_t, chunk_size):
            writer.append_data(frame)
    return True

def decode_with_vq_model(blendshapes, model):
    """Pass one sequence through `model` and return its first output without the batch axis.

    The sequence is wrapped into a batch of one and accompanied by an all-True
    boolean mask of shape (1, T). `model` must return exactly two values; the
    second is ignored. Runs under `torch.no_grad()`.
    """
    batch = blendshapes[None]
    all_valid = torch.ones((1, batch.shape[1]), dtype=torch.bool, device=batch.device)
    with torch.no_grad():
        reconstruction, _ignored = model(batch, all_valid)
    return reconstruction.squeeze(0)

def predict_from_audio(audio_path, model, target_style):
    """Predict blendshapes for an audio file with `model.predict_no_quantizer`.

    The audio is loaded at 16 kHz, passed through the module-level
    `wav2vec_processor`, reshaped to a single-row batch and moved to `device`.
    Inference runs under `torch.no_grad()`.

    Args:
        audio_path: Audio file readable by librosa.
        model: Object exposing `predict_no_quantizer(audio, target_style=...)`.
        target_style: Forwarded unchanged to the model.

    Returns:
        The model output with its leading dimension squeezed.
    """
    waveform, _ = librosa.load(audio_path, sr=16000)
    processed = wav2vec_processor(waveform, sampling_rate=16000).input_values
    flat = np.squeeze(processed)
    # One row holding every sample.
    as_batch = np.reshape(flat, (-1, flat.shape[0]))
    audio_feature = torch.FloatTensor(as_batch).to(device=device)
    with torch.no_grad():
        prediction = model.predict_no_quantizer(audio_feature, target_style=target_style)
    return prediction.squeeze(0)

style_path = "/mnt/fasttalk/demo/styles/style_2.npz"
# NOTE(review): style_path is never loaded in this cell, so the S2 prediction
# below runs with target_style=None — confirm that this is intended.
target_style = None


def _hstack_cmd(inputs, out_path, duration_sec):
    """Build the ffmpeg argv that places `inputs` side by side.

    Every video stream has its timestamps reset before stacking, the optional
    audio of the first input is copied, and the output is limited to
    `duration_sec`.
    """
    count = len(inputs)
    argv = ["ffmpeg", "-y"]
    for path in inputs:
        argv += ["-i", path]
    resets = ";".join(f"[{i}:v]setpts=PTS-STARTPTS[v{i}]" for i in range(count))
    labels = "".join(f"[v{i}]" for i in range(count))
    argv += [
        "-filter_complex",
        f"{resets};{labels}hstack=inputs={count}[v]",
        "-map", "[v]",
        "-map", "0:a?",
        "-t", str(duration_sec),
        "-c:v", "libx264",
        "-preset", FFMPEG_PRESET,
        "-crf", FFMPEG_CRF,
        "-c:a", "copy",
        "-shortest",
        out_path,
    ]
    return argv


def _run_ffmpeg(argv, ok_message, fail_prefix):
    """Run `argv`; print `ok_message` on success or `<fail_prefix>: <error>` on failure."""
    try:
        subprocess.run(argv, check=True)
        print(ok_message)
    except Exception as exc:
        print(f"{fail_prefix}: {exc}")


for npz_path in npz_files:
    stem = os.path.splitext(os.path.basename(npz_path))[0]
    # Everything needed is copied into tensors, so the archive can be closed immediately.
    with np.load(npz_path) as npz_data:
        extracted = extract_blendshapes_from_npz(npz_data)
    if extracted is None:
        print(f"Skipping {npz_path}: missing exp/pose data.")
        continue
    blendshapes, shape_t = extracted
    duration_sec = float(blendshapes.shape[0]) / float(FPS)

    # --- Tracking render straight from the archive -----------------------------
    out_video = os.path.join(output_dir, f"{stem}.mp4")
    print(f"Rendering {stem} -> {out_video}")
    if not render_sequence_from_blendshapes(blendshapes, shape_t, out_video, fps=FPS, chunk_size=CHUNK_SIZE):
        continue

    # --- FASTTALK S1: encode/decode and render ---------------------------------
    s1_blendshapes = decode_with_vq_model(blendshapes, vq_model)
    s1_video = os.path.join(output_dir, f"{stem}_s1.mp4")
    if not render_sequence_from_blendshapes(s1_blendshapes, shape_t, s1_video, fps=FPS, chunk_size=CHUNK_SIZE):
        print(f"S1 render failed for {stem}")
        continue

    # --- Audio-only (S2) prediction --------------------------------------------
    audio_path = find_audio_for_stem(stem)
    s2_video = None
    if not audio_path:
        print(f"No audio found for S2 prediction: {stem}")
    else:
        s2_blendshapes = predict_from_audio(audio_path, s2_model, target_style)
        # Trim prediction and shape to their common length before rendering.
        min_len = min(s2_blendshapes.shape[0], shape_t.shape[0])
        s2_blendshapes = s2_blendshapes[:min_len]
        shape_t_s2 = shape_t[:min_len]
        s2_target = os.path.join(output_dir, f"{stem}_s2_audio.mp4")
        if render_sequence_from_blendshapes(s2_blendshapes, shape_t_s2, s2_target, fps=FPS, chunk_size=CHUNK_SIZE):
            s2_video = s2_target
        else:
            print(f"S2 audio-only render failed for {stem}")

    real_video = find_real_video_for_stem(stem)
    have_ffmpeg = bool(shutil.which("ffmpeg"))

    # --- Side-by-side with real video (real left, tracking right), real audio --
    if have_ffmpeg and real_video:
        print(f"Real video found: {real_video}")
        side_by_side = os.path.join(output_dir, f"{stem}_side_by_side.mp4")
        _run_ffmpeg(
            _hstack_cmd([real_video, out_video], side_by_side, duration_sec),
            f"Saved side-by-side: {side_by_side}",
            f"Side-by-side failed for {stem}",
        )
    elif have_ffmpeg:
        print(f"No real video found for {stem}")

    # --- Tracking render with audio (audio file preferred over the real video) -
    if not have_ffmpeg:
        if audio_path or real_video:
            print(f"ffmpeg not found; audio not muxed for {stem}")
        else:
            print(f"No matching audio or real video found for {stem}")
    else:
        if audio_path:
            audio_source = audio_path
            print(f"Audio file found: {audio_source}")
        elif real_video:
            audio_source = real_video
            print(f"Using audio from real video: {audio_source}")
        else:
            audio_source = None
        if not audio_source:
            print(f"No matching audio or real video found for {stem}")
        else:
            out_with_audio = os.path.join(output_dir, f"{stem}_audio.mp4")
            mux_cmd = [
                "ffmpeg", "-y",
                "-i", out_video,
                "-i", audio_source,
                "-map", "0:v",
                "-map", "1:a?",
                "-t", str(duration_sec),
                "-c:v", "copy",
                "-c:a", "aac",
                "-shortest",
                out_with_audio,
            ]
            _run_ffmpeg(
                mux_cmd,
                f"Saved with audio: {out_with_audio}",
                f"Audio mux failed for {stem}",
            )

    # --- Quad concat: GT / SMIRK TRACKING / FASTTALK S1 / AUDIO-ONLY (S2) ------
    if not have_ffmpeg:
        print(f"ffmpeg not found; quad concat not created for {stem}")
    elif real_video and s1_video and s2_video:
        quad_out = os.path.join(output_dir, f"{stem}_gt_tracking_s1_s2audio.mp4")
        _run_ffmpeg(
            _hstack_cmd([real_video, out_video, s1_video, s2_video], quad_out, duration_sec),
            f"Saved quad concat: {quad_out}",
            f"Quad concat failed for {stem}",
        )
    else:
        print(f"Quad concat skipped for {stem} (missing real/s1/s2 video)")