In [None]:
import os
import re
import json
import gzip
import pickle
import math
import random
import warnings
from pathlib import Path
from decimal import Decimal
from collections import defaultdict
from multiprocessing import Pool, cpu_count
from functools import partial

import numpy as np
import pandas as pd
import torch
import torchaudio
from torch_audiomentations import ApplyImpulseResponse, AddBackgroundNoise
from textgrid import TextGrid
from tqdm.auto import tqdm

# Ignore unnecessary warnings
# Each `message` is a regular expression matched against the warning text;
# the leading ".*" lets the pattern match anywhere in the message.
warnings.filterwarnings("ignore", message=".*torchaudio.load_with_torchcodec.*")
warnings.filterwarnings("ignore", message=".*StreamingMediaDecoder.*")
# NOTE(review): this hides every warning whose text contains "deprecated",
# from any library - confirm that this breadth is intended.
warnings.filterwarnings("ignore", message=".*deprecated.*")

In [None]:
class Augmentor:
    """Optional waveform augmentation (room impulse response + background noise).

    Wraps ``ApplyImpulseResponse`` and ``AddBackgroundNoise`` from
    ``torch_audiomentations``. Augmentation is disabled entirely when both
    probabilities are 0; in that case no transform object is constructed.

    Args:
        rir_paths: Impulse-response source forwarded as ``ir_paths``. When
            falsy, no impulse-response transform is created.
        noise_paths: Noise source forwarded as ``background_paths``. When
            falsy, no background-noise transform is created.
        rir_prob: Forwarded as ``p`` to ``ApplyImpulseResponse``.
        noise_prob: Forwarded as ``p`` to ``AddBackgroundNoise``.
        min_snr_in_db: Forwarded to ``AddBackgroundNoise``.
        max_snr_in_db: Forwarded to ``AddBackgroundNoise``.
        sample_rate: Sample rate handed to both transforms (default 16000,
            the value that used to be hard-coded).
    """

    def __init__(self, rir_paths=None, noise_paths=None, rir_prob=0, noise_prob=0,
                 min_snr_in_db=3, max_snr_in_db=30, sample_rate=16000):
        self.do_aug = not (rir_prob == 0 and noise_prob == 0)

        # Define both attributes up front so the instance is fully initialised
        # even when augmentation is disabled (previously they did not exist
        # after the early return).
        self.rir_augmentor = None
        self.noise_augmentor = None
        if not self.do_aug:
            return

        if rir_paths:
            self.rir_augmentor = ApplyImpulseResponse(
                ir_paths=rir_paths, sample_rate=sample_rate, p=rir_prob, output_type='tensor'
            )

        if noise_paths:
            self.noise_augmentor = AddBackgroundNoise(
                background_paths=noise_paths, sample_rate=sample_rate,
                min_snr_in_db=min_snr_in_db, max_snr_in_db=max_snr_in_db,
                p=noise_prob, output_type='tensor'
            )

    def __call__(self, wav):
        """Alias for :meth:`augment`."""
        return self.augment(wav)

    def augment(self, wav):
        """Apply the configured transforms to ``wav``.

        When augmentation is disabled the input tensor is returned unchanged
        (same object, same shape). Otherwise a 1-D input is treated as
        ``(T,)`` and a 2-D input as ``(C, T)``; a leading batch axis is added
        for the transforms and removed again, so the result is ``(C, T)``
        (``(1, T)`` for a 1-D input).
        """
        if not self.do_aug:
            return wav

        if wav.dim() == 1:
            wav = wav.unsqueeze(0).unsqueeze(0)  # (1, 1, T)
        elif wav.dim() == 2:
            wav = wav.unsqueeze(0)  # (1, C, T)

        # Explicit None checks: do not rely on the truthiness of a module object.
        if self.rir_augmentor is not None:
            wav = self.rir_augmentor(wav)
        if self.noise_augmentor is not None:
            wav = self.noise_augmentor(wav)

        return wav.squeeze(0)  # Return (C, T)

class Extractor:
    """Log-mel spectrogram feature extraction with optional per-utterance CMVN.

    Args:
        apply_cmvn: When true, each mel bin is normalised to zero mean and
            unit standard deviation over the time axis of the utterance.
        **feature_args: Overrides for the ``torchaudio.transforms.MelSpectrogram``
            defaults used here (``sample_rate=16000, n_fft=400, n_mels=24,
            win_length=400, hop_length=160``).
    """

    def __init__(self, apply_cmvn=True, **feature_args):
        self.apply_cmvn = apply_cmvn
        mel_args = {'sample_rate': 16000, 'n_fft': 400, 'n_mels': 24, 'win_length': 400, 'hop_length': 160}
        mel_args.update(feature_args)
        self.extractor = torchaudio.transforms.MelSpectrogram(**mel_args)

    def __call__(self, wav):
        """Alias for :meth:`extract`."""
        return self.extract(wav)

    def extract(self, wav: torch.Tensor):
        """Return the ``(T, F)`` log10-mel feature matrix for ``wav``.

        A 1-D waveform is promoted to ``(1, T)`` first.
        NOTE(review): ``squeeze(0)`` only drops a singleton channel axis, so
        multi-channel input is not supported - callers are expected to pass mono.
        """
        if wav.dim() == 1:
            wav = wav.unsqueeze(0)  # Ensure (1, T) for mono

        spec = self.extractor(wav)  # (1, F, T)
        spec = spec.squeeze(0).transpose(0, 1)  # (T, F)
        spec = torch.log10(spec + 1e-6)  # small floor avoids log10(0)

        if self.apply_cmvn:
            mean = spec.mean(dim=0, keepdim=True)
            if spec.size(0) > 1:
                std = spec.std(dim=0, keepdim=True)
            else:
                # The unbiased std of a single frame is NaN, which would turn
                # the whole feature into NaN; fall back to zero variance so a
                # one-frame utterance normalises to zeros instead.
                std = torch.zeros_like(mean)
            spec = (spec - mean) / (std + 1e-9)

        return spec

In [None]:
import re
import ast
from decimal import Decimal
import numpy as np
import pandas as pd
from pathlib import Path
from textgrid import TextGrid
from tqdm.auto import tqdm

# Sample rate (Hz) and hop length (samples) used to convert times in seconds
# into frame indices when building frame-level labels.
SR = 16000
HOP = 160  # 10ms
# Matches a phone label consisting of one of the listed vowel symbols followed
# by exactly one digit (e.g. "AH0", "IY1"). Each match is counted as one syllable.
# NOTE(review): the symbols look like ARPAbet vowels with a stress digit -
# confirm against the phone set used by the aligner.
VOWEL_PATTERN = re.compile(r"^(AA|AE|AH|AO|AW|AY|EH|ER|EY|IH|IY|OW|OY|UH|UW)\d$")

def count_words_and_chars(text: str):
    """Return ``(word_count, letter_count)`` for ``text``.

    Leading and trailing non-word characters are stripped first. Words are
    the whitespace-separated tokens of what remains; letters are the
    characters for which ``str.isalpha()`` is true.
    """
    trimmed = re.sub(r'^\W+|\W+$', '', text)

    letter_total = 0
    for ch in trimmed:
        if ch.isalpha():
            letter_total += 1

    return len(trimmed.split()), letter_total

def parse_speakers_txt(speakers_txt_path: Path):
    """Parse a pipe-separated speakers file into ``{speaker_id: info}``.

    Lines starting with ``;`` are skipped, as are lines with fewer than three
    ``|``-separated fields. The first three fields (stripped) are stored as
    the speaker id, ``"gender"`` and ``"subset"``. Returns an empty dict when
    the file does not exist.
    """
    if not speakers_txt_path.exists():
        return {}

    table = {}
    with open(speakers_txt_path, 'r', encoding='utf-8') as handle:
        for raw in handle:
            if raw.startswith(";"):
                continue
            fields = [field.strip() for field in raw.strip().split("|")]
            if len(fields) < 3:
                continue
            spk, sex, part = fields[:3]
            table[spk] = {"gender": sex, "subset": part}
    return table

def parse_chapters_txt(chapters_txt_path: Path):
    """Parse a pipe-separated chapters file into ``{chapter_id: info}``.

    Lines that are blank, start with ``;``, or have fewer than seven
    ``|``-separated fields are skipped. For each remaining line the stripped
    fields are used as follows: field 0 is the chapter id (dict key), field 2
    is parsed as a float into ``"duration"`` (``None`` when it is not a valid
    number), field 4 is stored as ``"book_id"`` and field 6 as
    ``"chapter_title"``.

    NOTE(review): the LibriSpeech CHAPTERS.TXT header appears to list field 4
    as the project id and field 5 as the book id - confirm which column
    ``"book_id"`` is meant to hold. The original column choice is kept here.

    Returns an empty dict when the file does not exist.
    """
    chapters_info = {}
    if not chapters_txt_path.exists():
        return {}

    with open(chapters_txt_path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.startswith(";") or not line.strip():
                continue
            parts = line.strip().split("|")
            if len(parts) < 7:
                continue

            # float() on a string raises only ValueError; catching just that
            # avoids hiding unrelated errors (e.g. KeyboardInterrupt).
            try:
                duration = float(parts[2].strip())
            except ValueError:
                duration = None

            chapters_info[parts[0].strip()] = {
                "book_id": parts[4].strip(),
                "chapter_title": parts[6].strip(),
                "duration": duration,
            }
    return chapters_info


def parse_textgrid_segments_with_blanks(tg_path: Path):
    """
    Split a TextGrid word tier into segments that include the surrounding
    blank ("") intervals.

    A segment is a maximal run of consecutive non-blank word intervals:
    - start = xmin of the closest preceding blank interval, or 0.0 when no
      blank interval has been seen yet (first segment of the file)
    - end   = xmax of the blank interval that follows the run, or the xmax of
      the last word when the run reaches the end of the tier

    Returns a list of dicts with keys:
    - "words": list of (mark, xmin, xmax) tuples in absolute time
    - "start", "end", "duration": Decimal values
    - "num_syllables": number of phone intervals overlapping [start, end]
      whose mark matches VOWEL_PATTERN (0 when there is no phone tier)

    An empty list is returned when no tier named "words"/"word" is found or
    that tier is falsy. Times are handled as Decimal (built from str()) so the
    boundary arithmetic is exact.
    """

    tg = TextGrid.fromFile(str(tg_path))

    word_tier = next((t for t in tg.tiers if t.name and t.name.lower() in ["words", "word"]), None)
    phone_tier = next((t for t in tg.tiers if t.name and t.name.lower() in ["phones", "phonemes"]), None)
    if not word_tier:
        return []

    intervals = [(Decimal(str(iv.minTime)), Decimal(str(iv.maxTime)), (iv.mark or "").strip())
                 for iv in word_tier.intervals]

    # The vowel phones do not depend on the segment, so convert and filter
    # them once instead of re-scanning the whole phone tier per segment.
    vowel_spans = []
    if phone_tier:
        vowel_spans = [
            (Decimal(str(p.minTime)), Decimal(str(p.maxTime)))
            for p in phone_tier.intervals
            if VOWEL_PATTERN.match((p.mark or "").strip())
        ]

    segments = []
    prev_blank = None
    i, L = 0, len(intervals)

    while i < L:
        xmin, xmax, mark = intervals[i]

        # Blank intervals are remembered as prev_blank and skipped. This also
        # handles the blank that terminates a word run on the next iteration.
        if mark == "":
            prev_blank = (xmin, xmax)
            i += 1
            continue

        # Collect the run of consecutive words starting at i (never empty,
        # because intervals[i] is a word here).
        chunk_words = []
        last_word_xmax = xmax
        while i < L:
            w_min, w_max, w_mark = intervals[i]
            if w_mark == "":
                break
            chunk_words.append((w_mark, w_min, w_max))
            last_word_xmax = w_max
            i += 1

        # The blank interval right after the run, if any
        next_blank = intervals[i] if (i < L and intervals[i][2] == "") else None

        # Segment boundaries
        seg_start = Decimal("0.0") if prev_blank is None else prev_blank[0]
        seg_end = next_blank[1] if next_blank is not None else last_word_xmax

        # A vowel counts when its interval overlaps the open window
        # (seg_start, seg_end); intervals that merely touch a boundary do not.
        num_syllables = sum(
            1 for pmin, pmax in vowel_spans
            if pmax > seg_start and pmin < seg_end
        )

        segments.append({
            "words": chunk_words,      # Absolute coordinates
            "start": seg_start,
            "end": seg_end,
            "duration": (seg_end - seg_start),
            "num_syllables": num_syllables
        })

    return segments

# =========================
# Frame-level label generation
# =========================
def _times_to_frame_indices(times_sec: np.ndarray, sr: int, hop: int, total_frames: int):
    if times_sec.size == 0:
        return np.empty((0,), dtype=np.int32)
    frames = np.floor(times_sec * (sr / hop)).astype(np.int32)
    return frames[(frames >= 0) & (frames < total_frames)]

def make_frame_label_for_segment(seg_start: float, seg_end: float, word_abs_starts: np.ndarray):
    """
    Build the frame-level word-onset label vector for one segment.

    - total_frames = floor((seg_end - seg_start) * SR / HOP) + 1
    - each word start is made relative to seg_start and mapped to a frame
    - the result has 1 at every onset frame and 0 elsewhere

    Returns (label as list[int], total_frames); the list is empty when
    total_frames is not positive.
    """
    span_sec = float(seg_end - seg_start)
    total_frames = 1 + int(np.floor(span_sec * SR / HOP))
    if total_frames <= 0:
        return [], total_frames

    onset_frames = _times_to_frame_indices(
        word_abs_starts - float(seg_start), SR, HOP, total_frames
    )

    onsets = np.zeros((total_frames,), dtype=np.int8)
    # Assigning through an empty index array is a no-op, so no size check is needed.
    onsets[onset_frames] = 1
    return onsets.tolist(), total_frames

def process_textgrid_file(tg_path: Path,
                          audio_base_dir: Path,
                          speakers_info: dict,
                          chapters_info: dict,
                          subset: str,
                          speaker_id: str,
                          chapter_id: str):
    """
    Turn one TextGrid file into a list of per-segment metadata rows (dicts).

    The audio file is assumed to live at
    audio_base_dir/<speaker_id>/<chapter_id>/<TextGrid stem>.flac.
    A parsing failure is reported on stdout and yields an empty list.
    """
    try:
        segments = parse_textgrid_segments_with_blanks(tg_path)
    except Exception as e:
        print(f"Error parsing {tg_path}: {e}")
        return []

    if not segments:
        return []

    # Values shared by every segment of this file
    utt_id_prefix = tg_path.stem
    audio_path_str = str(audio_base_dir / speaker_id / chapter_id / f"{utt_id_prefix}.flac")
    # split(...)[-1] returns the whole string when the marker is absent
    relative_audio_path = audio_path_str.split("LibriSpeech/LibriSpeech/")[-1]
    speaker_meta = speakers_info.get(speaker_id, {})
    chapter_meta = chapters_info.get(chapter_id, {})

    rows = []
    for seg_idx, seg in enumerate(segments):
        words = seg["words"]

        # Transcript and simple counts
        full_text = " ".join(text for text, _, _ in words)
        num_words, num_chars = count_words_and_chars(full_text)

        # Frame-level onset labels are computed here, at metadata time
        onset_times = np.array([float(start) for _, start, _ in words], dtype=np.float64)
        frame_label_list, total_frames = make_frame_label_for_segment(
            seg["start"], seg["end"], onset_times
        )

        # Per-word timing in absolute seconds
        words_info = []
        for text, start, end in words:
            words_info.append({"text": text, "start": float(start), "end": float(end)})

        rows.append({
            "utt_id": f"{utt_id_prefix}_seg{seg_idx}",
            "subset": subset,
            "speaker_id": speaker_id,
            "chapter_id": chapter_id,
            "transcript": full_text,
            "num_words": num_words,
            "num_chars": num_chars,
            "num_syllables": int(seg["num_syllables"]),
            "start_time": float(seg["start"]),
            "end_time": float(seg["end"]),
            "duration": float(seg["duration"]),
            "audio_path": audio_path_str,
            "relative_audio_path": relative_audio_path,
            "gender": speaker_meta.get("gender", "NA"),
            "book_id": chapter_meta.get("book_id", "NA"),
            "chapter_title": chapter_meta.get("chapter_title", "NA"),
            "chapter_duration": chapter_meta.get("duration", "NA"),
            # frame-level
            "frame_label_list": frame_label_list,                  # Python list (original form)
            "frame_label": ",".join(map(str, frame_label_list)),   # CSV-friendly string
            "n_frames": total_frames,
            "words_info": words_info                               # JSON-serializable
        })

    return rows


# =========================
# Reader Directory Processing
# =========================
def process_reader_dir(reader_dir: Path,
                       alignment_base_dir: Path,
                       audio_base_dir: Path,
                       speakers_info: dict,
                       chapters_info: dict):
    """
    Collect metadata rows for every TextGrid under one speaker directory.

    reader_dir is expected to be alignments/<subset>/<speaker>; the subset and
    speaker id are taken from the directory names. Chapter directories and
    the *.TextGrid files inside them are visited in sorted order.
    alignment_base_dir is accepted but not used.
    """
    subset, speaker_id = reader_dir.parent.name, reader_dir.name
    chapter_dirs = sorted(entry for entry in reader_dir.glob("*") if entry.is_dir())

    collected = []
    for chapter_dir in chapter_dirs:
        for tg_path in sorted(chapter_dir.glob("*.TextGrid")):
            collected += process_textgrid_file(
                tg_path=tg_path,
                audio_base_dir=audio_base_dir / subset,
                speakers_info=speakers_info,
                chapters_info=chapters_info,
                subset=subset,
                speaker_id=speaker_id,
                chapter_id=chapter_dir.name,
            )
    return collected

def load_librispeech_counts_metadata(data_dir, align_dir, subset_filters=None):
    """
    Build the per-segment metadata table for one alignment directory.

    Args:
        data_dir: Audio root. SPEAKERS.TXT / CHAPTERS.TXT are looked up in
            data_dir, data_dir/"LibriSpeech" and data_dir's parent (first hit
            wins); missing files simply leave the corresponding info empty.
        align_dir: Directory whose immediate all-digit subdirectories are
            speaker folders, e.g. .../train-clean-100.
        subset_filters: Accepted for interface compatibility; currently unused.

    Returns:
        A pandas DataFrame with one row per segment. It is empty (no columns)
        when align_dir does not exist or yields no rows.
    """
    data_path = Path(data_dir)
    align_path = Path(align_dir)

    # Try to find metadata files
    speakers_map = {}
    chapters_map = {}

    candidates = [data_path, data_path / "LibriSpeech", data_path.parent]

    spk_file = next((p / "SPEAKERS.TXT" for p in candidates if (p / "SPEAKERS.TXT").exists()), None)
    ch_file = next((p / "CHAPTERS.TXT" for p in candidates if (p / "CHAPTERS.TXT").exists()), None)

    if spk_file:
        print(f"Loading speakers info from {spk_file}")
        speakers_map = parse_speakers_txt(spk_file)
    if ch_file:
        print(f"Loading chapters info from {ch_file}")
        chapters_map = parse_chapters_txt(ch_file)

    # align_dir is expected to be e.g. .../train-clean-100
    if not align_path.exists():
        print(f"Align dir {align_path} not found.")
        return pd.DataFrame()

    # iterdir() yields entries in arbitrary filesystem order; sort so the row
    # order is reproducible, matching the sorted iteration used for chapter
    # directories and TextGrid files.
    children = sorted(d for d in align_path.iterdir() if d.is_dir() and d.name.isdigit())
    print(f"Found {len(children)} speaker folders in {align_path}")

    all_rows = []
    for spk_dir in tqdm(children, desc="Processing Speakers"):
        all_rows.extend(
            process_reader_dir(spk_dir, align_path.parent, data_path, speakers_map, chapters_map)
        )

    return pd.DataFrame(all_rows)

In [None]:
# Multiprocessing Workers

# Module-level holders for the Augmentor / Extractor instances.
# They stay None until `_init_worker` assigns them; `_process_item` reads them.
_augmentor = None
_extractor = None

def _init_worker(aug_args: dict, feature_args: dict) -> None:
    """Create the module-level Augmentor and Extractor.

    Passed as the `initializer` of the multiprocessing Pool, so the objects
    are built by the initializer rather than inside `_process_item`.

    Args:
        aug_args: Keyword arguments for `Augmentor`.
        feature_args: Keyword arguments for `Extractor`.
    """
    global _augmentor, _extractor
    _augmentor = Augmentor(**aug_args)
    _extractor = Extractor(**feature_args)

def _process_item(row):
    """
    Process a single metadata row: Load Audio -> Augment -> Extract Feature
    """
    audio_path = row['audio_path']
    utt_id = row['utt_id']
    start_time = row.get('start_time', 0.0)
    end_time = row.get('end_time', None)
    
    try:
        # Load Audio (Full File)
        wav, sr = torchaudio.load(audio_path)
        if sr != 16000:
            wav = torchaudio.functional.resample(wav, sr, 16000)
        wav = wav.mean(dim=0) # Mix to mono if stereo
        
        # Crop segment if start/end time provided (from TextGrid segmentation)
        # Note: The new metadata logic provides segments. 
        if end_time is not None:
             s_sample = int(start_time * 16000)
             e_sample = int(end_time * 16000)
             if e_sample > wav.size(0): e_sample = wav.size(0)
             if s_sample < wav.size(0):
                 wav = wav[s_sample:e_sample]
                 
        # Augment & Extract
        if _augmentor: wav = _augmentor(wav)
        feat = _extractor(wav) # (T, n_mels)
        
        # --- Frame Labels (Onset) ---
        T = feat.shape[0]
        
        # Use pre-calculated frame_label_list from metadata
        frame_label_list = row.get('frame_label_list', [])
        
        # Handle string serialization from CSV
        if isinstance(frame_label_list, str):
            try:
                frame_label_list = ast.literal_eval(frame_label_list)
            except:
                frame_label_list = []
                
        # If list is empty or missing, fallback or zeros
        if not frame_label_list:
             frame_labels = np.zeros(T, dtype=np.float32)
        else:
             # Truncate or pad to match feature length T
             # Metadata total_frames might differ slightly due to rounding vs stft padding
             fl = np.array(frame_label_list, dtype=np.float32)
             if len(fl) > T:
                 frame_labels = fl[:T]
             elif len(fl) < T:
                 frame_labels = np.pad(fl, (0, T - len(fl)))
             else:
                 frame_labels = fl
        
        return {
            'utt_id': utt_id,
            'feat': feat.cpu().numpy(), # Save as numpy to save space/pickle
            'frame_label': frame_labels,
            'word_count': row['num_words'],
            'syllable_count': row['num_syllables'],
            'duration': row['duration'],
            'audio_path': audio_path
        }
    except Exception as e:
        # print(f"Failed {audio_path}: {e}")
        return None

def generate_dataset_parallel(meta_df, save_path, aug_args=None, feature_args=None, num_workers=None):
    """Extract features for every metadata row using a process pool.

    Args:
        meta_df: DataFrame of metadata rows; each row is passed to
            `_process_item` as a dict.
        save_path: Where to `torch.save` the result list; falsy to skip saving.
        aug_args: Keyword arguments for `Augmentor` (default: none).
        feature_args: Keyword arguments for `Extractor` (default: none).
        num_workers: Pool size; defaults to `cpu_count() - 2`, at least 1.

    Returns:
        List of result dicts from `_process_item`, in completion order
        (not input order). Items that failed are omitted.
    """
    if num_workers is None:
        num_workers = max(1, cpu_count() - 2)

    if aug_args is None:
        aug_args = {}
    if feature_args is None:
        feature_args = {}

    # Convert DataFrame to list of dicts for pool
    data_list = meta_df.to_dict('records')
    results = []

    print(f"Processing {len(data_list)} files with {num_workers} workers...")

    with Pool(processes=num_workers, initializer=_init_worker, initargs=(aug_args, feature_args)) as pool:
        for res in tqdm(pool.imap_unordered(_process_item, data_list), total=len(data_list)):
            if res is not None:
                results.append(res)

    # Make dropped items visible instead of silently shrinking the dataset
    n_failed = len(data_list) - len(results)
    if n_failed:
        print(f"Warning: {n_failed} of {len(data_list)} items failed and were skipped.")

    # Save Results
    if save_path:
        # dirname is "" for a bare file name, and os.makedirs("") raises
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        print(f"Saving {len(results)} items to {save_path}...")
        torch.save(results, save_path)  # Save as list of dicts

    return results

In [None]:
# --- MAIN EXECUTION ---

# Define Paths
# The `...` values are placeholders and must be replaced with real paths
# before this cell is run.
libri_data_dir = ...        # audio root passed to load_librispeech_counts_metadata
libri_algn_dir = ...        # alignment root; per-split subfolders are looked up inside it
libri_processed_dir = ...   # output root for the "processed" and "meta" folders

# Define Arguments
# Passed to Extractor(**feature_args): 'apply_cmvn' is its own keyword,
# the remaining keys override the MelSpectrogram defaults.
feature_args = {'n_mels': 24, 'apply_cmvn': True}

# Paths for augmentation resources
# Passed to Augmentor(**aug_args); used for splits whose name contains "train".
aug_args = {
    'rir_paths': ..., # UPDATE THIS
    'noise_paths': ..., # UPDATE THIS
    'rir_prob': 0.5, 
    'noise_prob': 0.5
}

# Both probabilities 0 disables augmentation; used for the non-train splits.
noaug_args = {'rir_prob': 0, 'noise_prob': 0}
splits = ["test-clean", "dev-clean", "test-other", "dev-other", "train-clean-100", "train-clean-360", "train-other-500"]

# Feature-file paths collected per category while looping over `splits`,
# merged into total_<category>.pt afterwards.
grouped_files = {
    'train': [],
    'valid': [],
    'test': []
}

processed_dir = os.path.join(libri_processed_dir, "processed")
meta_dir = os.path.join(libri_processed_dir, "meta")

os.makedirs(processed_dir, exist_ok=True)
os.makedirs(meta_dir, exist_ok=True)

# For each split: load or build the metadata CSV, then (unless the feature
# file already exists) extract features in parallel and save them as a dict
# keyed by utt_id.
for split in splits:
    print(f"\n=== Processing {split} ===")

    # Normalised split name used in file and directory names
    split_tag = split.replace('_', '-')
    meta_path = os.path.join(meta_dir, f"librispeech_{split_tag}_utt_meta_withframe.csv")
    feat_path = os.path.join(processed_dir, f"librispeech_{split_tag}_logmel_features_withframe.pt")

    # Categorize splits for later merging
    if "train" in split:
        grouped_files['train'].append(feat_path)
    elif "dev" in split:
        grouped_files['valid'].append(feat_path)
    elif "test" in split:
        grouped_files['test'].append(feat_path)

    # 1. Generate Metadata
    meta_df = None
    if os.path.exists(meta_path):
        print(f"Loading existing metadata from {meta_path}...")
        meta_df = pd.read_csv(meta_path)
        # Check if 'frame_label_list' column exists (needed for frame labels)
        if 'frame_label_list' not in meta_df.columns:
            print("Metadata missing 'frame_label_list'. Regenerating...")
            meta_df = None

    if meta_df is None:
        print("Generating metadata...")
        subset_align_dir = os.path.join(libri_algn_dir, split_tag)
        if not os.path.exists(subset_align_dir):
            # Fall back to the alignment root itself
            subset_align_dir = libri_algn_dir

        meta_df = load_librispeech_counts_metadata(libri_data_dir, subset_align_dir, subset_filters=[split_tag])
        if meta_df.empty:
            # Writing an empty frame would leave a CSV that pd.read_csv cannot
            # parse on the next run, so skip the split instead.
            print(f"No metadata rows found for {split}; skipping.")
            continue
        # frame_label_list is written to the CSV as its string representation
        meta_df.to_csv(meta_path, index=False)
        print(f"Saved metadata to {meta_path}")

    # 2. Generate Features
    # Check if feature file already exists
    if os.path.exists(feat_path):
        print(f"Features already exist: {feat_path}")
        continue

    # Augmentation is applied to training splits only
    is_train = "train" in split
    current_aug = aug_args if is_train else noaug_args

    # All results are gathered in memory; nothing is written until the merge below
    print(f"Starting parallel processing for features...")
    dataset_results = generate_dataset_parallel(
        meta_df,
        None,
        aug_args=current_aug,
        feature_args=feature_args
    )

    # 3. Merge results into the final dictionary keyed by utt_id
    print("Merging results...")
    final_data = {}
    for res in dataset_results:
        final_data[res['utt_id']] = {
            "feat": torch.tensor(res['feat']),
            "frame_label": torch.tensor(res['frame_label']),  # Include frame_label
            "word_count": res['word_count'],
            "syllable_count": res['syllable_count'],
            "duration": res['duration'],
            "audio_path": res.get('audio_path', '')
        }

    print(f"Saving merged file to {feat_path}...")
    torch.save(final_data, feat_path)
    print("Saved.")
    
# --- Aggregate Category Merging ---
# Combine the per-split feature dictionaries of each category into a single
# total_<category>.pt file inside processed_dir.
print("\n=== Merging Categories (Total Train, Valid, Test) ===")

for category, file_paths in grouped_files.items():
    print(f"Merging {category} set from {len(file_paths)} files...")
    if not file_paths:
        print(f"No files found for {category}, skipping.")
        continue

    merged_data = {}
    for fp in file_paths:
        if not os.path.exists(fp):
            print(f"  Warning: File {fp} not found, skipping.")
            continue
        print(f"  Loading {fp}...")
        # Later files overwrite earlier entries that share an utt_id
        merged_data.update(torch.load(fp))

    if not merged_data:
        # Do not write an empty aggregate that would hide the missing inputs
        print(f"No samples loaded for {category}, skipping save.")
        continue

    save_name = f"total_{category}.pt"
    save_path = os.path.join(processed_dir, save_name)
    print(f"Saving combined {category} set ({len(merged_data)} samples) to {save_path}...")
    torch.save(merged_data, save_path)
