In [1]:
import torch
import os
from pathlib import Path
# Change the working directory to the parent folder (presumably the project
# root that holds the local packages imported below -- confirm).
# Guarded with a kernel-lifetime flag: without it, re-running this cell would
# climb one more directory each time and break every relative path used later.
if "_PROJECT_ROOT" not in globals():
    os.chdir('../')
    _PROJECT_ROOT = Path.cwd()
# Load your model definition and dataset
from models import get_model
from types import SimpleNamespace
import yaml
import matplotlib.pyplot as plt
from flame_model.FLAME import FLAMEModel
from renderer.renderer import Renderer
from pytorch3d.transforms import matrix_to_euler_angles
import matplotlib.animation as animation
import numpy as np
from dataset.data_loader_joint_data import get_dataloaders

# Shared module-level objects used by the helper functions defined in later cells.
device   = torch.device("cuda" if torch.cuda.is_available() else "cpu")
flame    = FLAMEModel(n_shape=300,n_exp=50).to(device)
renderer = Renderer(render_full_head=True).to(device)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_and_flatten_yaml(config_path):
    """
    Load a YAML config file and flatten it by one level.

    Every top-level entry whose value is a mapping (e.g. DATA, NETWORK)
    contributes its sub-keys directly to the result; the section name itself
    is dropped. Top-level entries that are not mappings are kept under their
    own key. If the same key is produced more than once, the one read last
    wins.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        SimpleNamespace exposing the flattened keys as attributes. An empty
        YAML document yields an empty namespace.
    """
    with open(config_path, 'r') as f:
        # safe_load returns None for an empty document; treat it as "no settings"
        # instead of failing on None.items() below.
        full_config = yaml.safe_load(f) or {}

    flattened_config = {}
    for top_level_key, section in full_config.items():
        if isinstance(section, dict):
            # Hoist the section's own key/value pairs to the top level.
            flattened_config.update(section)
        else:
            # Non-mapping top-level value: keep it under its original key.
            flattened_config[top_level_key] = section

    return SimpleNamespace(**flattened_config)

In [3]:
# ---- Load model (without DDP for eval) ----
def load_model_for_eval(checkpoint_path,cfg):
    """
    Build the model described by ``cfg``, load its weights and switch it to eval mode.

    Args:
        checkpoint_path: File readable by ``torch.load``. It may contain either
            a dict that stores the weights under the "state_dict" key, or the
            state dict itself.
        cfg: Configuration object passed to ``get_model``.

    Returns:
        The model, moved to the module-level ``device`` and set to ``eval()``.
    """
    model = get_model(cfg).to(device)

    # map_location belongs to torch.load (not load_state_dict): it remaps the
    # stored tensors onto the active device, so a checkpoint written on a GPU
    # also loads on a CPU-only machine.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Accept both the wrapped ({"state_dict": ...}) and the bare layout.
    state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    model.load_state_dict(state_dict)

    model.eval()
    return model

# ---- Load dataset ----
def load_dataset(cfg,test_config):
    """
    Fetch the data loaders from ``get_dataloaders`` and return the requested split(s).

    Args:
        cfg: Configuration object forwarded to ``get_dataloaders``.
        test_config: Forwarded to ``get_dataloaders``; also selects what is returned.

    Returns:
        The 'test' loader when ``test_config`` is truthy, otherwise the tuple
        ``(train_loader, val_loader)`` taken from the 'train' and 'valid' entries.
    """
    loaders = get_dataloaders(cfg,test_config)

    # Test mode: a single loader.
    if test_config:
        return loaders['test']

    # Training mode: train/validation pair.
    return loaders['train'], loaders['valid']


# ---- Evaluate some samples ----
def evaluate_samples(model, data_loader, num_samples=5):
    """
    Run the model on the first ``num_samples`` batches and render a comparison for each.

    Each batch is unpacked as ``(vertice, blendshapes, template, _)``; the first
    three items are moved to the module-level ``device``, passed to ``model``
    under ``torch.no_grad()``, and the squeezed inputs/outputs are handed to
    ``render_comparison`` together with the batch index.

    NOTE(review): the recorded run of this notebook failed on the model call
    with "VQAutoEncoder.forward() takes 2 positional arguments but 4 were
    given", so the three-argument call (and possibly the four-value return
    unpacking) does not match the model built by ``get_model`` for the stage-1
    config. The correct signature is not visible here -- reconcile before use.

    Args:
        model: Callable invoked as ``model(vertice, blendshapes, template)``.
        data_loader: Iterable yielding 4-tuples as described above.
        num_samples: Number of batches to process.
    """
    for batch_idx, batch in enumerate(data_loader):
        vertice, blendshapes, template, _ = batch

        # Stop once the requested number of batches has been rendered.
        if batch_idx >= num_samples:
            break

        vertice, blendshapes, template = (
            tensor.to(device) for tensor in (vertice, blendshapes, template)
        )

        # Inference only: no autograd graph needed.
        with torch.no_grad():
            vertice_out, blendshapes_out, quant_loss, info = model(vertice, blendshapes, template)

        render_comparison(
            vertice.squeeze(),
            vertice_out.squeeze(),
            blendshapes.squeeze(),
            blendshapes_out.squeeze(),
            batch_idx,
        )


def get_vertices_from_blendshapes(expr, gpose, jaw, eyelids):
    """
    Run the module-level FLAME model on expression/pose parameters with a zero shape vector.

    The identity (shape) parameters are all zeros and both eye poses are the
    Euler angles of the identity rotation, so only expression, global pose and
    jaw pose drive the result.

    Args:
        expr: Expression parameters; the first dimension is used as the number
            of frames.
        gpose: Global pose parameters, concatenated with ``jaw`` along the last
            dimension to form FLAME's ``pose_params``.
        jaw: Jaw pose parameters.
        eyelids: Accepted for interface compatibility but currently unused.
            NOTE(review): eyelid parameters are never forwarded to FLAME --
            confirm whether ``flame.forward`` should receive them.

    Returns:
        The first output of ``flame.forward``, detached from the autograd graph.
    """
    expr_tensor = expr.to(device)
    num_frames  = expr_tensor.shape[0]

    # Zero shape coefficients, allocated directly on the target device.
    # 300 matches the n_shape used when the module-level FLAME model is built.
    target_shape_tensor = torch.zeros(num_frames, 300, device=device)

    # Euler angles ("XYZ") of the identity rotation, i.e. a neutral eye pose.
    neutral_eye = matrix_to_euler_angles(torch.eye(3)[None], "XYZ").to(device).squeeze()

    # Right and left eye share the neutral pose; broadcast over all frames.
    eyes = torch.cat([neutral_eye, neutral_eye], dim=0).expand(num_frames, -1)

    pose = torch.cat([gpose.to(device), jaw.to(device)], dim=-1)

    flame_output_only_shape, _ = flame.forward(shape_params=target_shape_tensor,
                                               expression_params=expr_tensor,
                                               pose_params=pose,
                                               eye_pose_params=eyes)
    return flame_output_only_shape.detach()

# Assumes flame and renderer are already defined and on correct device

def render_comparison(vertice_gt, vertice_pred, blendshapes_gt, blendshapes_pred, index):
    """
    Render ground-truth and predicted sequences side by side and save them as an MP4.

    Both blendshape tensors are split by column into expression (0:50),
    global pose (50:53), jaw pose (53:56) and eyelids (56:). Vertices are
    regenerated from those parameters via ``get_vertices_from_blendshapes``,
    rendered with the module-level ``renderer``, and written to
    ``demo/video/sample_<index>.mp4`` with ground truth on the left and the
    prediction on the right.

    Args:
        vertice_gt: Unused; vertices are recomputed from ``blendshapes_gt``.
        vertice_pred: Unused; vertices are recomputed from ``blendshapes_pred``.
        blendshapes_gt: 2-D tensor of ground-truth parameters, one row per frame.
        blendshapes_pred: 2-D tensor of predicted parameters, same layout.
        index: Integer used to build the zero-padded output file name.
    """
    # ==== Split GT and predicted blendshapes ====
    expr_gt     = blendshapes_gt[:, :50]
    gpose_gt    = blendshapes_gt[:, 50:53]
    jaw_gt      = blendshapes_gt[:, 53:56]
    eyelids_gt  = blendshapes_gt[:, 56:]

    expr_pr     = blendshapes_pred[:, :50]
    gpose_pr    = blendshapes_pred[:, 50:53]
    jaw_pr      = blendshapes_pred[:, 53:56]
    eyelids_pr  = blendshapes_pred[:, 56:]

    # ==== Generate vertices ====
    verts_gt = get_vertices_from_blendshapes(expr_gt, gpose_gt, jaw_gt, eyelids_gt)
    verts_pr = get_vertices_from_blendshapes(expr_pr, gpose_pr, jaw_pr, eyelids_pr)
    print(verts_gt.shape, verts_pr.shape)

    # ==== Camera ====
    # One fixed camera vector, repeated for every frame.
    cam = torch.tensor([5, 0, 0], dtype=torch.float32).unsqueeze(0).to(verts_gt.device)
    cam = cam.expand(verts_gt.shape[0], -1)

    # ==== Render both sequences ====
    frames_gt = renderer.forward(verts_gt, cam)['rendered_img']         # [T, 3, H, W]
    frames_pr = renderer.forward(verts_pr, cam)['rendered_img']         # [T, 3, H, W]

    # ==== Prepare output folder ====
    os.makedirs("demo/video", exist_ok=True)
    video_file = f"demo/video/sample_{index:03d}.mp4"

    # ==== Convert once, outside the animation callback ====
    def to_uint8_frames(frames):
        """Convert a [T, 3, H, W] float tensor to a [T, H, W, 3] uint8 array."""
        pixels = frames.detach().cpu().numpy().transpose(0, 2, 3, 1)
        # Assumes the renderer emits values in [0, 1] (implied by the *255
        # scaling). Clamp first: casting out-of-range floats to uint8 wraps
        # around and produces speckled pixels.
        return (np.clip(pixels, 0.0, 1.0) * 255).astype(np.uint8)

    gt_frames = to_uint8_frames(frames_gt)
    pr_frames = to_uint8_frames(frames_pr)

    # ==== Create animation ====
    def update(frame_idx, gt_seq, pr_seq, axes):
        # Ground truth on the left, prediction on the right.
        combined = np.concatenate([gt_seq[frame_idx], pr_seq[frame_idx]], axis=1)
        axes.clear()
        axes.imshow(combined)
        axes.axis("off")

    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ani = animation.FuncAnimation(
            fig,
            update,
            frames=gt_frames.shape[0],
            fargs=(gt_frames, pr_frames, ax),
            interval=100
        )
        ani.save(video_file, writer='ffmpeg', fps=25)
    finally:
        # Always release the figure, even if encoding fails (e.g. ffmpeg missing);
        # otherwise figures pile up when this is called in a loop.
        plt.close(fig)

    print(f"Saved video comparison to {video_file}")

In [4]:
# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Flattened stage-1 configuration (path is relative to the current working directory).
cfg = load_and_flatten_yaml("config/joint_data/stage1.yaml")

# NOTE: machine-specific absolute path -- adjust to the local checkpoint location.
checkpoint_path = "/root/Projects/fasttalk/logs/joint_data/joint_data_14k_s1/model_70/model.pth.tar"
model = load_model_for_eval(checkpoint_path, cfg)

# With test_config=False, load_dataset returns the (train, validation) loader pair.
loaders = load_dataset(cfg, test_config=False)
train_loader, val_loader = loaders

Loading data...


  1%|          | 100/14181 [00:08<19:22, 12.11it/s]


Loaded data: Train-70, Val-22, Test-9


In [5]:
# Render GT-vs-prediction comparison videos for the first 5 batches of the training loader.
# NOTE(review): the recorded run of this cell raised
# "TypeError: VQAutoEncoder.forward() takes 2 positional arguments but 4 were given";
# the model call inside evaluate_samples must be reconciled with the model's forward signature.
evaluate_samples(model, train_loader, num_samples=5)

TypeError: VQAutoEncoder.forward() takes 2 positional arguments but 4 were given

In [None]:
# With test_config=True, load_dataset returns only the 'test' loader.
test_loader = load_dataset(cfg, test_config=True)

In [None]:

# Render GT-vs-prediction comparison videos for the first 5 batches of the test loader.
evaluate_samples(model, test_loader, num_samples=5)