In [None]:
import os
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import torch
import torchaudio
import librosa
import random
from sklearn.model_selection import train_test_split

# --- Runtime configuration -------------------------------------------------
# Compute device: the CUDA GPU when PyTorch can see one, otherwise the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Locations, relative to the notebook's working directory.
DATA_CSV = './Data/data_labeled_filtered.csv'   # metadata table, one row per clip
AUDIO_ROOT = './Data/Audio'                      # the CSV's `path` column is relative to this
OUTPUT_DIR = './Data'                            # feature and split CSVs are written here

# Seed every RNG the pipeline draws from (augmentation uses `random` and `np.random`).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
def compute_formants(y, sr, n_formants=3, order=16):
    """Estimate the lowest `n_formants` formant frequencies (Hz) of a signal via LPC.

    An order-`order` LPC polynomial is fitted to the whole signal. Its complex roots
    (one of each conjugate pair) are converted to frequencies, and the lowest
    `n_formants` are returned.

    Args:
        y: 1-D signal samples.
        sr: sample rate of `y` in Hz.
        n_formants: number of frequencies to return.
        order: LPC order; the signal must have at least `order + 1` samples.

    Returns:
        np.ndarray of shape (n_formants,), ascending. Values lie strictly between 0 and
        sr/2. The array is zero-padded when fewer complex roots exist. It is all zeros
        when the signal is too short or the LPC fit fails (best-effort: errors are swallowed).
    """
    if len(y) < order + 1:
        return np.zeros(n_formants, dtype=float)

    try:
        a = librosa.lpc(y, order=order)
        if not np.isfinite(a).all():
            raise ValueError("LPC coefficients contain NaNs/Infs")

        roots = np.roots(a)
        # Keep one root of each complex-conjugate pair. The comparison is strict so that
        # real roots are dropped: they map to 0 Hz (positive) or exactly sr/2 (negative),
        # are not resonances, and would otherwise occupy the F1..F3 slots.
        roots = roots[np.imag(roots) > 0]
        freqs = np.sort(np.angle(roots) * (sr / (2 * np.pi)))

        if len(freqs) < n_formants:
            freqs = np.pad(freqs, (0, n_formants - len(freqs)), constant_values=0.0)

        return freqs[:n_formants]
    except Exception:
        return np.zeros(n_formants, dtype=float)


def collapse(x: torch.Tensor) -> torch.Tensor:
    """Drop a leading size-1 dimension (if there is one) and average over the last axis.

    For the (1, n_coeffs, n_frames) MFCC tensors produced in this notebook, this
    yields one time-averaged value per coefficient.
    """
    without_leading = x.squeeze(0)
    return torch.mean(without_leading, dim=-1)


# Cache of (MFCC, ComputeDeltas) transforms keyed by (sample rate, device).
# Constructing the MFCC transform builds its mel filterbank. Doing that once per clip
# (and once per augmented copy) inside the feature-extraction loop is wasted work.
_MFCC_TRANSFORMS = {}


def _mfcc_transforms(sr, device):
    """Return the cached (MFCC, ComputeDeltas) transform pair for `sr` on `device`."""
    key = (sr, str(device))
    if key not in _MFCC_TRANSFORMS:
        mfcc_tf = torchaudio.transforms.MFCC(
            sample_rate=sr, n_mfcc=13,
            melkwargs={'n_fft': 512, 'hop_length': 256, 'n_mels': 32}
        ).to(device)
        delta_tf = torchaudio.transforms.ComputeDeltas().to(device)
        _MFCC_TRANSFORMS[key] = (mfcc_tf, delta_tf)
    return _MFCC_TRANSFORMS[key]


def extract_gender_feats_torch(y, sr, device=DEVICE):
    """Compute a fixed-length acoustic feature vector for one waveform.

    Layout, in order (50 values with the settings used here):
      - pitch mean and std (Hz) over frames from torchaudio's pitch detector (50-300 Hz);
        both are 0.0 when no frames are detected, and the std is 0.0 for a single frame;
      - 3 formant estimates from `compute_formants`;
      - 13 MFCC means and 13 delta-MFCC means (time-averaged via `collapse`);
      - 7 spectral-contrast band means and 12 chroma means. These are all zeros if
        librosa fails on this input.

    Args:
        y: waveform samples. Presumably 1-D float, as returned by `librosa.load`
           (confirm for other callers).
        sr: sample rate of `y` in Hz.
        device: torch device used for the pitch and MFCC computations.

    Returns:
        np.ndarray of dtype float32.
    """
    batch = torch.tensor(y).float().unsqueeze(0).to(device)  # leading batch dim for MFCC
    waveform = batch.squeeze(0)

    feats = []

    pitch = torchaudio.functional.detect_pitch_frequency(
        waveform, sample_rate=sr, frame_time=0.01, freq_low=50., freq_high=300.
    )
    voiced = pitch[pitch > 0]
    n_voiced = voiced.numel()
    feats.append(float(voiced.mean()) if n_voiced > 0 else 0.0)
    # Tensor.std() is unbiased (divides by n - 1), so a single frame would yield NaN.
    feats.append(float(voiced.std()) if n_voiced > 1 else 0.0)

    np_wave = waveform.cpu().numpy()
    feats.extend(compute_formants(np_wave, sr).tolist())

    mfcc_tf, delta_tf = _mfcc_transforms(sr, device)
    mfcc = mfcc_tf(batch)
    mfcc_d = delta_tf(mfcc)
    feats.extend(collapse(mfcc).cpu().tolist())
    feats.extend(collapse(mfcc_d).cpu().tolist())

    try:
        sc = librosa.feature.spectral_contrast(y=np_wave, sr=sr, n_fft=512, hop_length=256)
        ch = librosa.feature.chroma_stft(y=np_wave, sr=sr, n_fft=512, hop_length=256)
        # Build both parts before extending, so a failure never leaves a partial append.
        spectral = sc.mean(axis=1).tolist() + ch.mean(axis=1).tolist()
    except Exception:
        # Best-effort: keep the vector length fixed (7 contrast bands + 12 chroma bins).
        spectral = [0.0] * (7 + 12)
    feats.extend(spectral)

    return np.array(feats, dtype=np.float32)


def augment_audio(y, sr):
    """Create randomly perturbed copies of a waveform for data augmentation.

    The returned list contains, in order:
      - `y` plus Gaussian noise (std 0.005). Always present, index 0.
      - With probability 0.5, a time-stretched copy with rate drawn from U(0.9, 1.1).
      - A pitch-shifted copy, shifted by one of -2, -1, +1, +2 steps. Always present, last.
    The list therefore has 2 or 3 entries, and the pitch-shifted copy's index varies.

    Draws from the global `np.random` and `random` generators, so results depend on
    seeding and on how many calls came before.
    """
    noisy = y + np.random.normal(0, 0.005, y.shape)
    augmented = [noisy]

    if random.random() < 0.5:
        rate = random.uniform(0.9, 1.1)
        # librosa >= 0.10 makes `rate` keyword-only; passing it positionally raises TypeError.
        augmented.append(librosa.effects.time_stretch(y, rate=rate))

    n_steps = random.choice([-2, -1, 1, 2])
    # Likewise, `sr` is keyword-only in librosa >= 0.10.
    augmented.append(librosa.effects.pitch_shift(y, sr=sr, n_steps=n_steps))

    return augmented


def generate_features():
    """Extract features for every clip in DATA_CSV, plus augmented copies, and save them.

    For each metadata row, the clip AUDIO_ROOT/<path> is loaded at 16 kHz and features are
    extracted from the original. Features are then extracted from augmented copies: only
    the noisy copy (index 0) for non-female rows, and every copy for female rows, which
    oversamples that class. A missing/NaN gender is treated as non-female.

    Each record has `path` (augmented copies get an `_aug<i>` suffix), `base_id`
    (the metadata row's index) and one column per feature. Because augmented copies
    share their source's base_id, splits can keep them together. The gender label is
    not stored; recover it by joining `base_id` to the metadata index.

    Clips that fail to load, or whose features fail to extract, are reported and skipped.

    NOTE(review): relies on a notebook-level `gender_cols` (one name per value returned
    by `extract_gender_feats_torch`), which is not defined in the visible cells.

    Returns:
        The features DataFrame, also written to OUTPUT_DIR/gender_features_full.csv.
    """
    df_meta = pd.read_csv(DATA_CSV)
    feature_cols = list(gender_cols)
    records = []

    def _record(path, base_id, feats):
        # zip() would silently drop trailing features (or columns) on a length mismatch.
        if len(feats) != len(feature_cols):
            raise ValueError(
                f"got {len(feats)} features but gender_cols has {len(feature_cols)} names"
            )
        return {"path": path, "base_id": base_id, **dict(zip(feature_cols, feats))}

    for row in tqdm(df_meta.itertuples(), total=len(df_meta), desc="Extracting & Augmenting"):
        path = os.path.join(AUDIO_ROOT, row.path)

        try:
            y, sr = librosa.load(path, sr=16000)
        except Exception as e:
            tqdm.write(f"Failed to load {path}: {e}")
            continue

        base_id = row.Index
        try:
            orig_feats = extract_gender_feats_torch(y, sr)
        except Exception as e:
            tqdm.write(f"Failed to extract features from {path}: {e}")
            continue
        records.append(_record(row.path, base_id, orig_feats))

        # str() makes a missing (NaN) gender compare as 'nan' instead of crashing.
        is_female = str(row.gender).lower() == 'female'
        for i, y_aug in enumerate(augment_audio(y, sr)):
            if not (is_female or i == 0):  # female rows keep every augmentation
                continue
            try:
                feats = extract_gender_feats_torch(y_aug, sr)
            except Exception as e:
                tqdm.write(f"Failed to extract features from {path} (aug{i}): {e}")
                continue
            records.append(_record(f"{row.path}_aug{i}", base_id, feats))

    df_feats = pd.DataFrame(records)
    df_feats.to_csv(os.path.join(OUTPUT_DIR, "gender_features_full.csv"), index=False)
    return df_feats


def split_data(df, test_size=0.2, val_size=0.1, random_state=42, output_dir=None):
    """Split feature rows into train/val/test by source clip and save each split as CSV.

    The split is done over unique `base_id` values, so every row sharing a base_id
    (an original clip and its augmented copies) lands in the same split.
    `test_size` is a fraction of all clips. `val_size` is a fraction of the remaining
    non-test clips, so the defaults give roughly 72% / 8% / 20% of clips.

    Notes:
        - Val and test keep augmented rows as well, so evaluation includes synthetic copies.
        - The split is not stratified by label.

    Args:
        df: DataFrame with a `base_id` column.
        test_size, val_size, random_state: forwarded to `train_test_split`.
        output_dir: directory for train.csv / val.csv / test.csv (defaults to OUTPUT_DIR).

    Returns:
        (df_train, df_val, df_test).
    """
    if output_dir is None:
        output_dir = OUTPUT_DIR

    clip_ids = df['base_id'].unique()
    train_ids, test_ids = train_test_split(clip_ids, test_size=test_size, random_state=random_state)
    train_ids, val_ids = train_test_split(train_ids, test_size=val_size, random_state=random_state)

    splits = {
        "train": df[df['base_id'].isin(train_ids)],
        "val": df[df['base_id'].isin(val_ids)],
        "test": df[df['base_id'].isin(test_ids)],
    }
    for name, part in splits.items():
        part.to_csv(os.path.join(output_dir, f"{name}.csv"), index=False)

    print("Splits saved.")
    for name, part in splits.items():
        print(f"{name.capitalize()}:", part.shape)

    return splits["train"], splits["val"], splits["test"]