In [None]:
import torch
import os
# Switch the working directory to the parent of the folder the kernel started in;
# presumably this is what lets the `dataset...` import and the relative
# `config/...` path used further down resolve.
# The starting directory is captured only once per kernel session, so re-running
# this cell always lands in the same place instead of climbing one more level
# on every execution.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))
from types import SimpleNamespace
import yaml
import numpy as np
from dataset.data_loader_joint_data_batched import get_dataloaders

# Use the GPU when PyTorch reports a usable CUDA runtime, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def load_and_flatten_yaml(config_path):
    """
    Load a YAML config file and flatten it by one level into attribute access.

    Every top-level entry whose value is a mapping (e.g., DATA, NETWORK, etc.)
    contributes its sub-keys directly to the result and the section name is
    dropped. Top-level entries whose value is not a mapping are kept under
    their own key. Only one level is flattened: mappings nested inside a
    section are stored unchanged.

    Key collisions are resolved silently in file order: if two sections define
    the same sub-key, the later one overwrites the earlier one.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        A SimpleNamespace exposing the flattened keys as attributes. An empty
        (or comment-only) file yields an empty namespace.

    Raises:
        TypeError: If a flattened key is not a string, since the keys are
            passed to SimpleNamespace as keyword arguments.
    """
    with open(config_path, 'r') as f:
        full_config = yaml.safe_load(f)

    # yaml.safe_load returns None for an empty document; treat that as
    # "no settings" rather than failing on `None.items()` below.
    if full_config is None:
        full_config = {}

    flattened_config = {}
    for top_level_key, section in full_config.items():
        if isinstance(section, dict):
            # Hoist the section's entries to the top level (later sections win).
            flattened_config.update(section)
        else:
            # Scalar/list top-level values keep their own key.
            flattened_config[top_level_key] = section

    return SimpleNamespace(**flattened_config)

In [None]:
# Read the YAML config into a flat attribute namespace (`cfg.<key>`).
# The path is relative, so it is resolved against the current working directory.
cfg = load_and_flatten_yaml("config/joint_data/stage2.yaml")

In [None]:
# Mode switch: False exposes the train/validation loaders, True the test loader.
# NOTE(review): how get_dataloaders itself interprets the second argument is not
# visible in this notebook — confirm against its definition.
# The cells below read `train_loader`, so they need test_config = False.
test_config = False
dataset = get_dataloaders(cfg, test_config)

# Keep only the splits that belong to the selected mode.
if test_config:
    test_loader = dataset['test']
else:
    train_loader = dataset['train']
    val_loader = dataset['valid']

In [None]:
# Fetch a single batch from the training loader and unpack its four components.
first_batch = next(iter(train_loader))
padded_blendshapes, blendshape_mask, padded_audios, audio_mask = first_batch

In [None]:
# Display the shape of the blendshape batch.
# NOTE(review): axis meaning is defined by the data loader — confirm there.
padded_blendshapes.shape

In [None]:
# Display the shape of the mask returned alongside the blendshape batch.
blendshape_mask.shape

In [None]:
# Display the shape of the audio batch.
# NOTE(review): axis meaning is defined by the data loader — confirm there.
padded_audios.shape

In [None]:
# Display the shape of the mask returned alongside the audio batch.
audio_mask.shape

In [None]:
import IPython.display as ipd

# Take the batch item at index 1, drop its size-1 dimensions, and convert it
# to a NumPy array.
# detach()/cpu() leave a plain CPU tensor's data unchanged, but keep the
# conversion working if the batch lives on the GPU or carries autograd
# history (Tensor.numpy() raises in both of those cases).
# NOTE(review): the original comment gave the expected shape as (640*T,) — confirm.
audio_np = padded_audios[1].detach().cpu().squeeze().numpy()
audio_np.shape

In [None]:

# Render an inline audio player for the extracted clip.
# NOTE(review): 16000 Hz is an assumed sampling rate — confirm against the dataset.
SAMPLE_RATE_HZ = 16000
ipd.Audio(data=audio_np, rate=SAMPLE_RATE_HZ)