In [None]:
# ==============================
# Clean dataset pipeline (PyTorch-ready)
# ==============================
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple
from collections import Counter

import numpy as np
import torch

@dataclass(frozen=True)
class SampleMeta:
    """Immutable metadata record attached to a single audio segment.

    ``process_dataset`` in this file builds one instance for every
    (spectrogram, segment) pair it emits; ``to_pytorch_tensors`` later reads
    ``condition`` and ``vowel_type`` from it to derive the label tensors.
    """
    # First '-'-separated field of the file stem (see parse_filename).
    subject_id: str
    # Second '-'-separated field of the file stem; used as the domain label.
    vowel_type: str
    # Third '-'-separated field of the file stem; used to derive the task label.
    condition: str
    # The file reference's ``name`` attribute, or its ``str()`` if it has none.
    filename: str
    # Zero-based position of the segment among those produced for its file.
    segment_id: int
    # Filled from process_dataset's ``default_sr`` argument; presumably the
    # sampling rate in Hz — TODO confirm it matches the actual audio.
    sr: int

# Type of the preprocessing callback consumed by process_dataset. It is invoked
# as ``fn(file_path, vowel_type=...)`` and its result is unpacked into two
# sequences (spectrograms, segments) that are paired element by element.
# NOTE(review): the first parameter is annotated ``str`` although
# process_dataset forwards entries of ``audio_files`` unchanged, which may be
# path objects — confirm the intended type.
PreprocessFn = Callable[[str, Optional[str]], Tuple[List[np.ndarray], List[np.ndarray]]]

def parse_filename(file_stem: str) -> Tuple[str, str, str]:
    """Split a '-'-separated file stem into (subject_id, vowel_type, condition).

    Only the first three fields are used; additional fields are ignored.
    A field that is absent or empty is replaced by its default:
    ``"unknown"`` for the subject, ``"a"`` for the vowel and ``"unknown"``
    for the condition.
    """
    fallbacks = ("unknown", "a", "unknown")
    fields = file_stem.split("-")
    resolved = [
        fields[pos] if pos < len(fields) and fields[pos] else fallback
        for pos, fallback in enumerate(fallbacks)
    ]
    return resolved[0], resolved[1], resolved[2]

def build_domain_index(vowels: Iterable[str]) -> Dict[str, int]:
    """Assign a contiguous integer id to every distinct vowel label.

    Duplicates are collapsed and the distinct labels are sorted, so the ids
    (0, 1, 2, ...) follow the labels' sort order and do not depend on the
    order in which the labels were supplied.
    """
    distinct = list(set(vowels))
    distinct.sort()
    return dict(zip(distinct, range(len(distinct))))

def map_condition_to_task(condition: str) -> int:
    """Translate a condition code into the binary task label.

    ``"h"`` and ``"lhl"`` map to 1. ``"l"``, ``"n"`` and every other value
    (including unrecognised codes) map to 0. Matching is exact and
    case-sensitive.
    """
    positive_codes = {"h", "lhl"}
    if condition in positive_codes:
        return 1
    return 0

def process_dataset(audio_files: Sequence, preprocess_fn: PreprocessFn, max_files: Optional[int] = None, default_sr: int = 44100):
    """Run ``preprocess_fn`` over audio files and flatten the output into samples.

    Args:
        audio_files: Sliceable sequence of file references (path objects or
            path strings). Each entry is handed to ``preprocess_fn`` unchanged.
        preprocess_fn: Callback invoked as
            ``preprocess_fn(file_path, vowel_type=vowel_type)``; it must return
            a ``(spectrograms, segments)`` pair of sequences.
        max_files: When truthy, only ``audio_files[:max_files]`` is processed.
            ``None`` (and ``0``) mean "no limit".
        default_sr: Value stored in the ``sr`` field of every ``SampleMeta``;
            it is not checked against the audio itself.

    Returns:
        A list of dicts with keys ``"spectrogram"``, ``"segment"`` and
        ``"metadata"`` (a ``SampleMeta``), one per (spectrogram, segment)
        pair. An empty list is returned when ``audio_files`` is empty.

    Files whose preprocessing raises are reported on stdout and skipped, as
    are files for which no spectrograms are returned. Spectrograms and
    segments are paired with ``zip``, so surplus items on the longer side are
    dropped.
    """
    # Function-scope import: pathlib is not among the file-level imports.
    from pathlib import PurePath

    if not audio_files:
        print("❌ 'audio_files' está vacío.")
        return []
    files_to_process = list(audio_files[:max_files]) if max_files else list(audio_files)
    print(f"🔄 Procesando {len(files_to_process)} archivos...")
    dataset = []
    for file_path in files_to_process:
        # Path objects expose ``stem`` directly. A plain string does not, and
        # parsing the raw string would leak the directory into subject_id and
        # the extension into condition (e.g. "h.wav"), which then silently
        # maps to task label 0. Derive the stem explicitly in that case.
        stem = getattr(file_path, "stem", None)
        if stem is None:
            stem = PurePath(str(file_path)).stem
        subject_id, vowel_type, condition = parse_filename(stem)
        try:
            spectrograms, segments = preprocess_fn(file_path, vowel_type=vowel_type)
        except Exception as e:
            # Best effort by design: one unreadable file must not abort the run.
            print(f"⚠️ Error en {file_path}: {e}")
            continue
        if not spectrograms:
            continue
        # Identical for every segment of this file, so resolve it once.
        file_name = getattr(file_path, "name", str(file_path))
        for j, (spec, seg) in enumerate(zip(spectrograms, segments)):
            dataset.append({
                "spectrogram": spec,
                "segment": seg,
                "metadata": SampleMeta(subject_id, vowel_type, condition, file_name, j, default_sr)
            })
    print(f"✅ {len(dataset)} muestras generadas")
    return dataset

def to_pytorch_tensors(dataset: List[Dict]):
    """Convert the sample dicts into model inputs and two label tensors.

    Args:
        dataset: Sample dicts carrying at least the keys ``"spectrogram"``
            and ``"metadata"``; the metadata must expose ``condition`` and
            ``vowel_type``.

    Returns:
        ``(X, y_task, y_domain, metas)``. ``X`` is a float tensor made by
        stacking all spectrograms along a new leading axis with an extra
        axis of size 1 inserted right after it. ``y_task`` and ``y_domain``
        are ``torch.long`` tensors holding the task label derived from each
        condition and the integer id of each vowel. ``metas`` lists the
        metadata objects in dataset order. For an empty dataset the result is
        ``(None, None, None, [])``.
    """
    if not dataset:
        print("❌ Dataset vacío.")
        return None, None, None, []

    metas = [entry["metadata"] for entry in dataset]
    # Vowel ids are derived from the labels present in this dataset only.
    vowel_to_id = build_domain_index([meta.vowel_type for meta in metas])

    task_labels = [map_condition_to_task(meta.condition) for meta in metas]
    domain_labels = [vowel_to_id[meta.vowel_type] for meta in metas]

    # Stack first, then insert the singleton axis at position 1; this yields
    # the same array as expanding every spectrogram before stacking.
    stacked = np.stack([entry["spectrogram"] for entry in dataset], axis=0)
    X = torch.from_numpy(np.expand_dims(stacked, axis=1)).float()

    y_task_t = torch.tensor(task_labels, dtype=torch.long)
    y_domain_t = torch.tensor(domain_labels, dtype=torch.long)
    print("📊 Tensores listos:", X.shape, y_task_t.shape, y_domain_t.shape)
    return X, y_task_t, y_domain_t, metas

class VowelSegmentsDataset(torch.utils.data.Dataset):
    """Map-style dataset over pre-built inputs, labels and metadata.

    The four containers are stored as given and indexed with the same
    position; the dataset length is the size of the first axis of ``X``.
    """

    def __init__(self, X, y_task, y_domain, metas):
        self.X = X
        self.y_task = y_task
        self.y_domain = y_domain
        self.metas = metas

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict keyed by field name."""
        item = {}
        item["X"] = self.X[idx]
        item["y_task"] = self.y_task[idx]
        item["y_domain"] = self.y_domain[idx]
        item["meta"] = self.metas[idx]
        return item

def build_full_pipeline(audio_files: Optional[Sequence], preprocess_fn: PreprocessFn, max_files: Optional[int] = None):
    """Run preprocessing, tensor conversion and dataset wrapping in one call.

    Args:
        audio_files: File references forwarded to ``process_dataset``; may be
            ``None`` or empty.
        preprocess_fn: Preprocessing callback forwarded to ``process_dataset``.
        max_files: Optional file limit forwarded to ``process_dataset``.

    Returns:
        A dict with keys ``"dataset"`` (list of sample dicts), ``"tensors"``
        (``(X, y_task, y_domain)``), ``"torch_ds"`` (a
        ``VowelSegmentsDataset`` or ``None``) and ``"metadata"``. When there
        are no input files, or preprocessing yields no samples, the dict
        holds empty lists, a triple of ``None`` and ``None`` respectively.
    """

    def _empty_result():
        # Built afresh on every call so callers never share the inner lists.
        return {"dataset": [], "tensors": (None, None, None), "torch_ds": None, "metadata": []}

    if not audio_files:
        print("❌ 'audio_files' no está definido.")
        return _empty_result()

    samples = process_dataset(audio_files, preprocess_fn, max_files)
    if not samples:
        return _empty_result()

    X, y_task, y_domain, metas = to_pytorch_tensors(samples)
    if X is None:
        torch_ds = None
    else:
        torch_ds = VowelSegmentsDataset(X, y_task, y_domain, metas)
    return {"dataset": samples, "tensors": (X, y_task, y_domain), "torch_ds": torch_ds, "metadata": metas}


In [None]:
# Ejemplo de uso (ajusta audio_files y preprocess_audio_paper a tu entorno)
# results = build_full_pipeline(audio_files=audio_files, preprocess_fn=preprocess_audio_paper)
# dataset = results["dataset"]
# X, y_task, y_domain = results["tensors"]
# torch_dataset = results["torch_ds"]
