In [1]:
import os
import warnings
from itertools import islice
from pathlib import Path

import torch

# Change the base folder to the parent of the directory the kernel started in.
# The starting directory is recorded only on the first run, so re-running this
# cell (or Run All on a live kernel) does not climb one more level each time.
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = Path.cwd()
os.chdir(NOTEBOOK_DIR.parent)

In [2]:
# Load your model definition and dataset
from models import get_model
from types import SimpleNamespace
import yaml
import matplotlib.pyplot as plt
from flame_model.FLAME import FLAMEModel
from renderer.renderer import Renderer
from pytorch3d.transforms import matrix_to_euler_angles
import matplotlib.animation as animation
import numpy as np
from dataset.data_loader_artalk import get_dataloaders
from base.baseTrainer import load_state_dict
import glob
import librosa
from transformers import Wav2Vec2Processor, Wav2Vec2FeatureExtractor
import torch

# Use the GPU when CUDA is available; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# FLAME head model; n_shape / n_exp presumably set the number of shape and
# expression coefficients — confirm against FLAMEModel.
flame = FLAMEModel(n_exp=50, n_shape=300).to(device)
# Renderer with render_full_head enabled, placed on the same device as FLAME.
renderer = Renderer(render_full_head=True).to(device)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def load_and_flatten_yaml(config_path):
    """
    Load a YAML config and flatten its top-level sections into one namespace.

    Every key nested one level under a top-level section (e.g. DATA, NETWORK)
    becomes an attribute of the returned namespace; the section names
    themselves are dropped. Top-level entries whose value is not a mapping
    are kept as-is under their own key.

    Args:
        config_path: Path (str or os.PathLike) to the YAML file.

    Returns:
        types.SimpleNamespace with one attribute per flattened key. An empty
        YAML file yields an empty namespace.

    Raises:
        TypeError: if the YAML document is not a mapping at the top level.

    Warns:
        UserWarning: if the same key appears more than once after flattening
        (e.g. `batch_size` under two sections); the entry seen last wins.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        full_config = yaml.safe_load(f)

    # yaml.safe_load returns None for an empty document.
    if full_config is None:
        full_config = {}
    if not isinstance(full_config, dict):
        raise TypeError(
            f"expected a mapping at the top of {config_path}, "
            f"got {type(full_config).__name__}"
        )

    flattened_config = {}
    origin = {}  # flattened key -> top-level key it came from (for collision messages)
    for top_level_key, sub_dict in full_config.items():
        # Sections contribute their children; non-dict top-level entries contribute themselves.
        if isinstance(sub_dict, dict):
            entries = sub_dict.items()
        else:
            entries = [(top_level_key, sub_dict)]
        for k, v in entries:
            if k in flattened_config:
                # Flattening loses the section name, so a silent overwrite would hide a config bug.
                warnings.warn(
                    f"config key {k!r} from {top_level_key!r} overrides the value "
                    f"from {origin[k]!r}",
                    stacklevel=2,
                )
            flattened_config[k] = v
            origin[k] = top_level_key

    return SimpleNamespace(**flattened_config)

In [4]:
# Load the stage-1 config as a flat namespace (section names are dropped).
# `global` is a no-op at notebook/module scope, so cfg is simply a module-level name.
CONFIG_PATH = "config/artalk_ensemble/stage1.yaml"
cfg = load_and_flatten_yaml(CONFIG_PATH)

In [5]:
# ####################### Data Loader ####################### #
# NOTE(review): this rebinds `get_dataloaders`, replacing the one imported from
# dataset.data_loader_artalk in the imports cell; the multi loader is what runs here.
from dataset.data_loader_multi import get_dataloaders

dataset = get_dataloaders(cfg)  # indexed by split name ('train', 'valid', ...)
train_loader = dataset['train']
# Always rebind val_loader so a re-run with evaluation disabled cannot keep a stale
# loader from an earlier run. `evaluate` may be missing from the YAML; treat that as off.
val_loader = dataset['valid'] if getattr(cfg, 'evaluate', False) else None


Loading data...


  1%|▏         | 100/7971 [00:03<05:00, 26.15it/s]

Loaded data: Train-93, Val-0, Test-8





In [6]:
# Export blendshape parameters of the first few training batches as .npz files.
output_dir = 'demo/output'
N_EXPORT = 5  # number of training batches to export

# np.savez does not create missing parent directories.
os.makedirs(output_dir, exist_ok=True)

# islice stops after N_EXPORT batches without fetching an extra one from the loader.
for i, (vertice, blendshapes, template, _) in enumerate(islice(train_loader, N_EXPORT)):
    # Expected layout: [1, T, D], i.e. a single-sample batch. With a larger batch,
    # squeeze(0) is a no-op and the column slices below would cut the wrong axis.
    if blendshapes.shape[0] != 1:
        raise ValueError(
            f"expected batch size 1, got blendshapes of shape {tuple(blendshapes.shape)}"
        )
    # detach/cpu so .numpy() also works for CUDA tensors or tensors requiring grad.
    blendshapes = blendshapes.squeeze(0).detach().cpu().numpy()  # [T, D]

    # Column layout assumed by this export: 50 expression, 3 global pose, 3 jaw, rest eyelids.
    exp     = blendshapes[:, :50]    # [T, 50]
    gpose   = blendshapes[:, 50:53]  # [T, 3]
    jaw     = blendshapes[:, 53:56]  # [T, 3]
    # NOTE(review): takes every column after 56; this is empty ([T, 0]) if D == 56 — confirm D.
    eyelids = blendshapes[:, 56:]    # [T, D - 56]

    base_name = f"sample_{i}"
    np.savez(os.path.join(output_dir, f"{base_name}.npz"), exp=exp, gpose=gpose, jaw=jaw, eyelids=eyelids)