## CNN+MLP (Proposed Model)

In [None]:
# ============================================================
# FULL CODE
# "Machine Learning Framework for Speech Intelligibility Prediction using Binaural Room Impulse Responses"
# Alfian et al., 2026
# ============================================================

import os, glob
import numpy as np
import librosa
from scipy.io import wavfile

import tensorflow as tf
from tensorflow.keras import layers, models, regularizers
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint

import keras_tuner as kt
from sklearn.utils import shuffle
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# =========================
# CONFIG
# =========================
# Dataset root; files are read from <DATA_ROOT>/<class name>/<split>/*.wav
# (see load_splits_multidecay).
DATA_ROOT = "/content/drive/MyDrive/DATASET V3"
# Class folder name -> integer training label.
classes = {"High Intelligibility": 1, "Low Intelligibility": 0}
SPLITS  = ["training", "validation", "testing"]

# Segment boundaries in seconds: early part = [0, T_EARLY_END),
# late part = [T_LATE_START, T_TOTAL). Every file is padded/trimmed to T_TOTAL.
T_EARLY_END  = 0.050
T_LATE_START = 0.050
T_TOTAL      = 3.0

# Audio is resampled to SR_TARGET; the late-segment STFT uses N_FFT/HOP and
# keeps only bins within [FMIN, FMAX] Hz, converted to dB when TO_DB.
SR_TARGET = 22050
N_FFT     = 1024
HOP       = 512
FMIN      = 125.0
FMAX      = 8000.0
TO_DB     = True
# NOTE(review): not passed to load_splits_multidecay below, which uses its own
# default random_state=42.
RANDOM_STATE = 42

# NOTE(review): these two callbacks are not referenced in the visible code (the
# tuner search and the top-K retraining construct their own); kept in case
# another cell uses them.
early_stopping = EarlyStopping(monitor="val_loss", patience=7, restore_best_weights=False,
                               mode="min", min_delta=1e-4, verbose=1)
reduce_lr      = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, min_lr=1e-6,
                                   mode='min', verbose=1)

# =========================
# HELPERS
# =========================
def list_wavs(split_dir: str):
    """Return the WAV files directly inside *split_dir*, sorted and de-duplicated.

    Files with a ``.wav``, ``.Wav`` or ``.WAV`` extension are matched.

    * The result is sorted, so the downstream seeded shuffle is reproducible
      regardless of the filesystem's directory-listing order.
    * Results are de-duplicated because on case-insensitive filesystems the
      three patterns can match the same file.

    Returns an empty list when the directory does not exist or has no WAVs.
    """
    patterns = ("*.wav", "*.Wav", "*.WAV")
    found = {path for pat in patterns for path in glob.glob(os.path.join(split_dir, pat))}
    return sorted(found)

def safe_log(x, eps=1e-12):
    """Natural log of *x* after clamping values below *eps* up to *eps*.

    The clamp keeps zero or negative inputs from producing -inf/NaN.
    """
    clipped = np.maximum(x, eps)
    return np.log(clipped)

def safe_div(a, b, eps=1e-12):
    """Return ``a / (b + eps)``; the offset keeps ``b == 0`` from dividing by zero."""
    denominator = b + eps
    return a / denominator

def load_audio_stereo(file_path, sr_target=22050, duration=3.0):
    """Load a WAV file as float32 stereo, resampled and padded/trimmed to a fixed length.

    Parameters
    ----------
    file_path : str
        Path to a WAV file readable by ``scipy.io.wavfile``.
    sr_target : int
        Output sample rate. Each channel is resampled with librosa when the
        file's rate differs.
    duration : float
        Output length in seconds. Shorter signals are zero-padded at the end
        and longer ones are truncated.

    Returns
    -------
    y : np.ndarray, float32, shape (round(duration * sr_target), 2)
        Mono input is duplicated to both channels.
    fs : int
        Sample rate of ``y`` (equal to ``sr_target``).

    Raises
    ------
    ValueError
        If the file has a channel count other than 1 or 2.
    """
    fs, y = wavfile.read(file_path)

    if y.dtype == np.uint8:
        # 8-bit PCM WAV is unsigned with silence at 128. Centre it before
        # scaling; otherwise every sample would carry a +0.5 DC offset.
        y = (y.astype(np.float32) - 128.0) / 128.0
    elif np.issubdtype(y.dtype, np.integer):
        y = y.astype(np.float32) / np.iinfo(y.dtype).max
    else:
        y = y.astype(np.float32)

    if y.ndim == 1:
        y = np.stack([y, y], axis=1)
    if y.shape[1] != 2:
        raise ValueError("Input must be mono or stereo 2-channel.")

    if fs != sr_target:
        yL = librosa.resample(y[:, 0], orig_sr=fs, target_sr=sr_target)
        yR = librosa.resample(y[:, 1], orig_sr=fs, target_sr=sr_target)
        y = np.stack([yL, yR], axis=1).astype(np.float32)
        fs = sr_target

    target_len = int(round(duration * fs))
    if y.shape[0] < target_len:
        pad = target_len - y.shape[0]
        y = np.pad(y, ((0, pad), (0, 0)), mode="constant", constant_values=0.0)
    else:
        y = y[:target_len, :]

    return y, fs

def extract_late_spectrogram(y_stereo, fs, t_start=0.050, t_end=3.0,
                             n_fft=1024, hop_length=512, window="hann",
                             to_db=True, fmin=125.0, fmax=8000.0):
    """Per-channel power spectrogram of the late segment ``[t_start, t_end)``.

    * Both boundaries are converted to samples and clamped to the signal length.
    * Each of the two channels gets a non-centred STFT. The power is optionally
      converted to dB, and only the frequency bins within ``[fmin, fmax]`` Hz
      are kept.
    * Non-finite values are replaced by 0.

    Returns
    -------
    np.ndarray, float32, shape (n_kept_bins, n_frames, 2)

    Raises
    ------
    ValueError
        If the clamped late segment is empty.
    """
    n_samples = y_stereo.shape[0]

    def _to_clamped_index(seconds):
        return max(0, min(int(round(seconds * fs)), n_samples))

    lo = _to_clamped_index(t_start)
    hi = _to_clamped_index(t_end)
    if hi <= lo:
        raise ValueError("Late segment empty.")
    late = y_stereo[lo:hi, :]

    bin_freqs = librosa.fft_frequencies(sr=fs, n_fft=n_fft)
    band = np.where((bin_freqs >= fmin) & (bin_freqs <= fmax))[0]

    def _channel_spectrogram(sig):
        power = np.abs(librosa.stft(sig, n_fft=n_fft, hop_length=hop_length,
                                    window=window, center=False)) ** 2
        if to_db:
            # ref=np.max is evaluated per channel: each channel's own loudest
            # bin becomes 0 dB. The dB output therefore does not keep the
            # absolute level difference between the two channels.
            power = librosa.power_to_db(power, ref=np.max)
        return power[band, :].astype(np.float32)

    stacked = np.stack([_channel_spectrogram(late[:, c]) for c in range(2)], axis=-1)
    return np.nan_to_num(stacked, nan=0.0, posinf=0.0, neginf=0.0)

def pad_or_trim_feat(feat: np.ndarray, target_shape):
    """Crop and/or zero-pad a ``(F, T, C)`` array to ``target_shape``.

    The frequency and time axes are cropped from the end when too long and
    zero-padded at the end when too short. The channel count must already
    match.

    Raises
    ------
    ValueError
        If the channel counts differ.
    """
    ref_f, ref_t, ref_c = target_shape
    _, _, n_ch = feat.shape
    if n_ch != ref_c:
        raise ValueError(f"Channel mismatch: feat has {n_ch}, ref {ref_c}")

    out = feat[:min(feat.shape[0], ref_f), :min(feat.shape[1], ref_t), :]
    missing_f = max(0, ref_f - out.shape[0])
    missing_t = max(0, ref_t - out.shape[1])
    if missing_f or missing_t:
        out = np.pad(out, ((0, missing_f), (0, missing_t), (0, 0)),
                     mode="constant", constant_values=0.0)
    return out

# ---------- core utilities for multi-decay ----------
def energy_envelope(sig, fs, frame_ms=10):
    """Non-overlapping frame energy envelope of a 1-D signal.

    Frames are ``round(frame_ms * fs / 1000)`` samples long (at least 16) and
    advance by their own length. A trailing partial frame is dropped, and a
    signal shorter than one frame yields empty arrays.

    Returns
    -------
    t : np.ndarray, float64
        Start time of each frame in seconds, relative to the first sample.
    energies : np.ndarray, float64
        Mean squared amplitude of each frame plus 1e-12 (linear, never zero).
    """
    frame_len = max(16, int(round(frame_ms/1000.0 * fs)))
    frame_starts = range(0, len(sig) - frame_len + 1, frame_len)
    energies = np.asarray(
        [np.mean(sig[s:s + frame_len]**2) + 1e-12 for s in frame_starts],
        dtype=np.float64,
    )
    times = np.arange(len(energies), dtype=np.float64) * (frame_len / fs)
    return times, energies

def linear_slope(t, y):
    """Ordinary least-squares slope of *y* against *t*.

    Returns 0.0 when fewer than two points are given. A 1e-12 term in the
    denominator guards against identical *t* values.
    """
    if len(t) < 2:
        return 0.0
    dt = t - t.mean()
    dy = y - y.mean()
    return float(np.sum(dt * dy) / (np.sum(dt**2) + 1e-12))

def count_peaks(x, min_prominence=0.30):
    """Count strict local maxima of 1-D *x* whose prominence proxy reaches *min_prominence*.

    * A sample is a peak when it is strictly greater than both immediate
      neighbours; endpoints are never peaks.
    * Its prominence proxy is the peak value minus the smaller neighbour.
    * *min_prominence* is in natural-log units when *x* is a log-energy
      envelope; 0.30 is roughly 2.6 dB.

    Returns the count as a float (0.0 for inputs shorter than 3 samples).
    """
    if len(x) < 3:
        return 0.0
    centre, left, right = x[1:-1], x[:-2], x[2:]
    rise = centre - left
    fall = centre - right
    strict_peak = (rise > 0) & (fall > 0)
    if not strict_peak.any():
        return 0.0
    prominence = (centre - np.minimum(left, right))[strict_peak].astype(np.float64)
    return float(np.count_nonzero(prominence >= min_prominence))

# =========================
# FEATURE EXTRACTION (EARLY + LATE + MULTI-DECAY)
# =========================
def extract_early_and_late_features_with_multidecay(
    y_stereo,
    fs,
    t_early_end=0.050,
    t_late_start=0.050,
    t_total=3.0,
    env_frame_ms=10
):
    """Hand-crafted early / late / multi-decay feature vector for one stereo signal.

    The signal is split into an early segment ``[0, t_early_end)`` (at least
    8 samples when the signal is that long) and a late segment
    ``[t_late_start, t_total)``. Channel 0 is the numerator of every ILD.

    Feature layout (27 values, float32), in order:

    * A) Early, 7 per channel (channel 0 then channel 1):
      log RMS of the early segment, log RMS of the whole signal, log
      early/total energy ratio, log spectral centroid, log spectral bandwidth,
      log 85% roll-off (these three on the early segment), and early
      zero-crossing rate. Then the ILD (dB) of the early segment and of the
      whole signal. Subtotal: 16 values.
    * B) Late, 2 per channel: log late/early energy ratio and the slope of
      the late-segment log-energy envelope (natural-log energy per second).
      Then the late-segment ILD (dB). Subtotal: 5 values.
    * C) Multi-decay, on the channel-averaged late segment (times relative to
      the late-segment start), 6 values:
      - envelope slope over [0, 0.07] s;
      - envelope slope over (0.07, 0.35] s;
      - their difference;
      - log "rebound" ratio (maximum envelope energy from the frame nearest
        0.03 s onwards, divided by that frame's energy);
      - standard deviation of the log-energy envelope;
      - number of prominent envelope peaks.

    Raises
    ------
    ValueError
        If the late segment is empty, e.g. ``t_late_start`` at or beyond
        ``t_total`` or the signal length.
    """
    n_samples = y_stereo.shape[0]

    nE = int(round(t_early_end * fs))
    nS = int(round(t_late_start * fs))
    nT = int(round(t_total * fs))

    # Clamp the segment boundaries to the signal.
    nE = max(8, min(nE, n_samples))
    nS = max(0, min(nS, n_samples))
    nT = max(1, min(nT, n_samples))
    if nT <= nS:
        # Same guard as extract_late_spectrogram. Without it, the late RMS and
        # ILD features would silently become NaN (mean of an empty array).
        raise ValueError("Late segment empty.")

    yE = y_stereo[:nE, :]
    yL = y_stereo[nS:nT, :]

    feats = []

    # -------------------------
    # A) EARLY FEATURES (7 per channel + 2 ILDs)
    # -------------------------
    for ch in [0, 1]:
        sigE = yE[:, ch]
        sigT = y_stereo[:, ch]

        rmsE = np.sqrt(np.mean(sigE**2) + 1e-12)
        rmsT = np.sqrt(np.mean(sigT**2) + 1e-12)
        eRatio = safe_div(np.sum(sigE**2), np.sum(sigT**2))

        centroid  = librosa.feature.spectral_centroid(y=sigE, sr=fs).mean()
        bandwidth = librosa.feature.spectral_bandwidth(y=sigE, sr=fs).mean()
        rolloff   = librosa.feature.spectral_rolloff(y=sigE, sr=fs, roll_percent=0.85).mean()
        zcr       = librosa.feature.zero_crossing_rate(sigE).mean()

        feats += [
            safe_log(rmsE), safe_log(rmsT),
            safe_log(eRatio),
            safe_log(centroid), safe_log(bandwidth), safe_log(rolloff),
            zcr
        ]

    # Interaural level differences in dB (channel 0 relative to channel 1).
    rmsE_L = np.sqrt(np.mean(yE[:,0]**2) + 1e-12)
    rmsE_R = np.sqrt(np.mean(yE[:,1]**2) + 1e-12)
    ildE = 20.0 * np.log10(safe_div(rmsE_L, rmsE_R))

    rmsT_L = np.sqrt(np.mean(y_stereo[:,0]**2) + 1e-12)
    rmsT_R = np.sqrt(np.mean(y_stereo[:,1]**2) + 1e-12)
    ildT = 20.0 * np.log10(safe_div(rmsT_L, rmsT_R))

    feats += [ildE, ildT]

    # -------------------------
    # B) LATE BASELINE FEATURES (2 per channel + late ILD)
    # -------------------------
    late_feats = []

    for ch in [0, 1]:
        sigE = yE[:, ch]
        sigL = yL[:, ch]

        eE = np.sum(sigE**2) + 1e-12
        eL = np.sum(sigL**2) + 1e-12

        late_over_early = safe_log(eL / eE)

        # Overall decay slope of the late-segment log-energy envelope.
        t_env, E_env = energy_envelope(sigL, fs, frame_ms=env_frame_ms)
        logE = np.log(E_env)
        slope_all = linear_slope(t_env, logE)

        late_feats += [late_over_early, slope_all]

    # ILD on the late segment
    rmsL_L = np.sqrt(np.mean(yL[:,0]**2) + 1e-12)
    rmsL_R = np.sqrt(np.mean(yL[:,1]**2) + 1e-12)
    ildL = 20.0 * np.log10(safe_div(rmsL_L, rmsL_R))
    late_feats += [ildL]

    # -------------------------
    # C) MULTI-DECAY EXTRA FEATURES (6)
    # -------------------------
    sigL_mean = 0.5 * (yL[:,0] + yL[:,1])
    t_env, E_env = energy_envelope(sigL_mean, fs, frame_ms=env_frame_ms)
    logE = np.log(E_env)

    # Piecewise slopes; times are relative to the late-segment start.
    # With the default t_late_start=0.05 s:
    #   [0, 0.07] s      ~ 50-120 ms absolute
    #   (0.07, 0.35] s   ~ 120-400 ms absolute
    t1_end = 0.07
    t2_end = 0.35

    idx1 = np.where(t_env <= t1_end)[0]
    idx2 = np.where((t_env > t1_end) & (t_env <= t2_end))[0]

    slope1 = linear_slope(t_env[idx1], logE[idx1]) if len(idx1) >= 2 else 0.0
    slope2 = linear_slope(t_env[idx2], logE[idx2]) if len(idx2) >= 2 else 0.0
    delta_slope = float(slope2 - slope1)

    # Energy rebound index: the reference frame is the one nearest 0.03 s
    # into the late segment (~80 ms absolute with the defaults). The ratio is
    # >= 1, and larger values mean energy rises again after the reference.
    t_ref = 0.03
    if len(t_env) > 0:
        i_ref = int(np.argmin(np.abs(t_env - t_ref)))
        E_ref = E_env[i_ref]
        E_after_max = np.max(E_env[i_ref:]) if i_ref < len(E_env) else E_ref
        rebound = float(safe_div(E_after_max, E_ref))
    else:
        rebound = 1.0

    # Variability of the decay (std of log-energy)
    std_logE = float(np.std(logE)) if len(logE) > 1 else 0.0

    # Peak count on log-energy (captures multiple bursts / secondary clusters)
    n_peaks = float(count_peaks(logE, min_prominence=0.30))

    multi_feats = [slope1, slope2, delta_slope, safe_log(rebound), std_logE, n_peaks]

    all_feats = np.asarray(feats + late_feats + multi_feats, dtype=np.float32)
    return all_feats

# =========================
# LOAD DATASET
# =========================
def load_splits_multidecay(
    data_root,
    sr_target=22050,
    duration_total=3.0,
    t_early_end=0.050,
    t_late_start=0.050,
    n_fft=1024,
    hop_length=512,
    fmin=125.0,
    fmax=8000.0,
    to_db=True,
    shuffle_data=True,
    random_state=42,
    verbose=True
):
    """Load every split and build, per WAV file, a late-segment spectrogram and a feature vector.

    Files come from ``<data_root>/<class name>/<split>/`` for every entry of
    the module-level ``classes`` dict and every split in ``SPLITS``. For each
    file:

    * ``load_audio_stereo`` loads the audio.
    * ``extract_early_and_late_features_with_multidecay`` builds the feature
      vector.
    * ``extract_late_spectrogram`` builds the spectrogram.

    The first successfully processed spectrogram fixes the reference shape;
    later ones are cropped/zero-padded to it. Any exception while processing
    a file counts that file as skipped (printed when ``verbose``) instead of
    aborting the load.

    Returns
    -------
    (train, val, test, meta)
        Each split is ``(Xspec, Xfeat, y, paths)``:

        * Xspec: float32 spectrogram stack.
        * Xfeat: float32 feature matrix.
        * y: int32 labels.
        * paths: object array of file paths.

        The four are shuffled jointly with ``random_state`` when
        ``shuffle_data``. ``meta`` holds ``spec_shape`` and ``feat_dim``.

    Raises
    ------
    ValueError
        If any split ends up without valid samples.
    """
    # File paths per split, kept aligned with the arrays below so samples can
    # be traced back to their WAV file.
    P_by_split = {s: [] for s in SPLITS}

    Xspec_by_split = {s: [] for s in SPLITS}
    Xfeat_by_split = {s: [] for s in SPLITS}
    y_by_split     = {s: [] for s in SPLITS}

    # Spectrogram shape of the first successfully processed file; all later
    # spectrograms are forced to it so they can be stacked.
    shape_ref = None
    skipped = {s: 0 for s in SPLITS}
    total   = {s: 0 for s in SPLITS}

    for split_name in SPLITS:
        for class_name, label in classes.items():
            split_dir = os.path.join(data_root, class_name, split_name)
            wav_files = list_wavs(split_dir)
            if verbose:
                print(f"Loading {len(wav_files)} files from {split_dir} (label={label})")

            for fpath in wav_files:
                total[split_name] += 1
                try:
                    y_st, fs = load_audio_stereo(fpath, sr_target=sr_target, duration=duration_total)

                    # Hand-crafted vector for the MLP branch (envelope frames fixed at 10 ms).
                    v_feat = extract_early_and_late_features_with_multidecay(
                        y_st, fs,
                        t_early_end=t_early_end,
                        t_late_start=t_late_start,
                        t_total=duration_total,
                        env_frame_ms=10
                    )

                    # Late-segment spectrogram for the CNN branch.
                    spec = extract_late_spectrogram(
                        y_st, fs,
                        t_start=t_late_start, t_end=duration_total,
                        n_fft=n_fft, hop_length=hop_length,
                        to_db=to_db, fmin=fmin, fmax=fmax
                    )

                    if shape_ref is None:
                        shape_ref = spec.shape
                        if verbose:
                            print("Reference spectrogram shape:", shape_ref)
                            print("Feature dim:", v_feat.shape)

                    if spec.shape != shape_ref:
                        spec = pad_or_trim_feat(spec, shape_ref)

                    Xspec_by_split[split_name].append(spec.astype(np.float32))
                    Xfeat_by_split[split_name].append(v_feat.astype(np.float32))
                    y_by_split[split_name].append(int(label))

                    P_by_split[split_name].append(fpath)

                except Exception as e:
                    # Best-effort loading: the file is skipped and counted.
                    # The error is printed only when verbose.
                    skipped[split_name] += 1
                    if verbose:
                        print(f"[ERROR] {fpath} -> {e}")

    def finalize(split_name):
        """Stack one split into arrays, optionally shuffle it, and check it is non-empty.

        Returns ``((Xspec, Xfeat), y, paths)``.
        """
        Xs = np.asarray(Xspec_by_split[split_name], dtype=np.float32)
        Xf = np.asarray(Xfeat_by_split[split_name], dtype=np.float32)
        y  = np.asarray(y_by_split[split_name], dtype=np.int32)

        P = np.array(P_by_split[split_name], dtype=object)

        if shuffle_data:
            # One joint permutation keeps spectrograms, features, labels and paths aligned.
            Xs, Xf, y, P = shuffle(Xs, Xf, y, P, random_state=random_state)

        if verbose:
            print(f"\nSplit='{split_name}': total={total[split_name]}, used={len(y)}, skipped={skipped[split_name]}")
            print("  Xspec:", Xs.shape)
            print("  Xfeat:", Xf.shape)
            print("  y    :", y.shape)

        if len(y) == 0:
            # Message (Indonesian): "No valid data for split '<name>'."
            raise ValueError(f"Tidak ada data valid untuk split '{split_name}'.")
        return (Xs, Xf), y, P

    (Xs_tr, Xf_tr), y_tr, p_tr = finalize("training")
    (Xs_va, Xf_va), y_va, p_va = finalize("validation")
    (Xs_te, Xf_te), y_te, p_te = finalize("testing")

    meta = {"spec_shape": shape_ref, "feat_dim": Xf_tr.shape[1]}
    return (Xs_tr, Xf_tr, y_tr, p_tr), (Xs_va, Xf_va, y_va, p_va), (Xs_te, Xf_te, y_te, p_te), meta

# =========================
# LOAD
# =========================
# Extract spectrograms + feature vectors for all three splits using the CONFIG values.
(train_pack, val_pack, test_pack, meta) = load_splits_multidecay(
    data_root=DATA_ROOT,
    sr_target=SR_TARGET,
    duration_total=T_TOTAL,
    t_early_end=T_EARLY_END,
    t_late_start=T_LATE_START,
    n_fft=N_FFT,
    hop_length=HOP,
    fmin=FMIN,
    fmax=FMAX,
    to_db=TO_DB,
    verbose=True
)

Xspec_train, Xfeat_train, y_train, p_train = train_pack
Xspec_val,   Xfeat_val,   y_val,   p_val   = val_pack
Xspec_test,  Xfeat_test,  y_test,  p_test  = test_pack

print("\nMeta:", meta)
# Per-sample input shapes for the two model branches (batch axis dropped).
spec_input_shape = Xspec_train.shape[1:]
feat_input_shape = (Xfeat_train.shape[1],)

# =========================
# NORMALIZATION
# =========================
# Statistics are fitted on the TRAINING split only. axis=None -> one scalar
# mean/variance over every spectrogram value; axis=-1 -> one mean/variance
# per hand-crafted feature. Both layers are reused inside build_model.
spec_norm = layers.Normalization(axis=None)
feat_norm = layers.Normalization(axis=-1)
spec_norm.adapt(Xspec_train)
feat_norm.adapt(Xfeat_train)

# =========================
# MODEL + TUNER
# =========================
# Global seeds for TensorFlow/Keras and NumPy. The data shuffle above uses its
# own random_state inside load_splits_multidecay.
SEED = 42
tf.keras.utils.set_random_seed(SEED)
np.random.seed(SEED)

def build_model(hp: kt.HyperParameters):
    """Build and compile the dual-branch CNN+MLP classifier for one hyperparameter set.

    Architecture:

    * Spectrogram branch: the shared ``spec_norm`` layer, then three
      Conv2D -> BatchNorm -> MaxPool -> Dropout stages, then global average
      pooling.
    * Feature branch: the shared ``feat_norm`` layer, then Dense -> Dropout ->
      Dense.
    * Head: the two embeddings are concatenated and passed through
      Dense -> Dropout to a 2-way softmax.

    Training uses Adam and sparse categorical cross-entropy.

    Hyperparameters are registered, and layers created, in a fixed order so
    seeded runs stay reproducible. The function relies on the module-level
    ``spec_norm``/``feat_norm`` and ``spec_input_shape``/``feat_input_shape``.
    """
    lr = hp.Choice("lr", [1e-4, 3e-4, 1e-3])
    l2w = hp.Choice("l2w", [0.0, 1e-4, 1e-3])
    reg = None
    if l2w > 0:
        reg = regularizers.l2(l2w)

    # Registration order: filters1..3, k1..3, then (pool_f, pool_t) per stage.
    filter_options = ([8, 16, 32], [16, 32, 64], [32, 64, 128])
    conv_filters = [hp.Choice(f"filters{i}", opts)
                    for i, opts in enumerate(filter_options, start=1)]
    conv_kernels = [hp.Choice(f"k{i}", [3, 5]) for i in range(1, 4)]
    conv_pools = [(hp.Choice(f"pool_f{i}", [2, 4]), hp.Choice(f"pool_t{i}", [2, 4]))
                  for i in range(1, 4)]

    drop_cnn = hp.Float("drop_cnn", 0.0, 0.5, step=0.1)

    feat_u1 = hp.Choice("feat_u1", [16, 32, 64])
    feat_u2 = hp.Choice("feat_u2", [8, 16, 32])
    drop_feat = hp.Float("drop_feat", 0.0, 0.5, step=0.1)

    head_u = hp.Choice("head_u", [32, 64, 128])
    drop_head = hp.Float("drop_head", 0.0, 0.6, step=0.1)

    inp_spec = layers.Input(shape=spec_input_shape, dtype=tf.float32, name="spec_in")
    inp_feat = layers.Input(shape=feat_input_shape, dtype=tf.float32, name="feat_in")

    # Spectrogram branch
    x = spec_norm(inp_spec)
    for n_filters, k, pool in zip(conv_filters, conv_kernels, conv_pools):
        x = layers.Conv2D(n_filters, (k, k), padding="same", activation="relu",
                          kernel_regularizer=reg)(x)
        x = layers.BatchNormalization()(x)
        x = layers.MaxPooling2D(pool)(x)
        x = layers.Dropout(drop_cnn)(x)
    spec_emb = layers.GlobalAveragePooling2D()(x)

    # Hand-crafted feature branch
    f = feat_norm(inp_feat)
    f = layers.Dense(feat_u1, activation="relu")(f)
    f = layers.Dropout(drop_feat)(f)
    feat_emb = layers.Dense(feat_u2, activation="relu")(f)

    # Fusion head
    merged = layers.Concatenate()([spec_emb, feat_emb])
    merged = layers.Dense(head_u, activation="relu")(merged)
    merged = layers.Dropout(drop_head)(merged)
    probs = layers.Dense(2, activation="softmax")(merged)

    model = models.Model([inp_spec, inp_feat], probs)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"]
    )
    return model

class MyBayesTuner(kt.BayesianOptimization):
    """Bayesian-optimisation tuner that also searches over the batch size.

    ``batch_size`` is a training argument rather than a model hyperparameter.
    It is therefore registered on the trial here and injected into the
    keyword arguments passed on to the base ``run_trial``.
    """

    def run_trial(self, trial, *args, **kwargs):
        batch_size = trial.hyperparameters.Choice("batch_size", [8, 16, 32])
        kwargs.update(batch_size=batch_size)
        return super().run_trial(trial, *args, **kwargs)

# Bayesian hyperparameter search minimising validation loss:
# - 20 trials, the first 6 chosen at random;
# - one training run per trial;
# - overwrite=True discards any earlier project in the same directory.
tuner = MyBayesTuner(
    hypermodel=build_model,
    objective=kt.Objective("val_loss", direction="min"),
    max_trials=20,
    executions_per_trial=1,
    directory="tuner_logs_multidecay",
    project_name="dualbranch_multidecay",
    overwrite=True,
    num_initial_points=6
)

# Per-trial callbacks: stop after 5 epochs without val_loss improvement
# (restoring the best weights), and halve the LR after 2 stagnant epochs.
tuner_early = EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True, mode="min", verbose=1)
tuner_reduce = ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=2, min_lr=1e-6, mode="min", verbose=1)

# NOTE(review): the search runs without class_weight, while the top-K
# retraining below uses one, so trial ranking and final training optimise
# slightly different objectives.
tuner.search(
    [Xspec_train, Xfeat_train], y_train,
    validation_data=([Xspec_val, Xfeat_val], y_val),
    epochs=60,
    callbacks=[tuner_early, tuner_reduce],
    verbose=1
)

# =========================
# CHOOSE TOP-K HYPERPARAMETERS, RETRAIN, SELECT ON VALIDATION, REPORT TEST
# =========================
# The final candidate is chosen by VALIDATION loss. Test metrics are only
# reported, never used for selection, so the final test score remains an
# unbiased estimate. (Selecting on test loss would leak the test set into
# model selection.)
TOPK = 5
top_hps = tuner.get_best_hyperparameters(num_trials=TOPK)

# Class weights (applied identically to every candidate): inverse class
# frequency ("balanced"), with class 0 (Low) additionally scaled by 1.3.
n0 = int(np.sum(y_train == 0))
n1 = int(np.sum(y_train == 1))
w0 = (n0 + n1) / (2.0 * n0 + 1e-12)
w1 = (n0 + n1) / (2.0 * n1 + 1e-12)
class_weight = {0: float(w0 * 1.3), 1: float(w1 * 1.0)}
print("\nClass weight used:", class_weight)

results = []

for rank, hp in enumerate(top_hps, start=1):
    print("\n" + "="*70)
    print(f"[CANDIDATE #{rank}] Hyperparameters:")
    for k in sorted(hp.values.keys()):
        print(f"  {k}: {hp.get(k)}")

    # Same seed for every candidate, so differences come from the hyperparameters.
    tf.keras.backend.clear_session()
    tf.keras.utils.set_random_seed(SEED)
    np.random.seed(SEED)

    model_k = build_model(hp)
    bs = hp.get("batch_size")

    # The checkpoint keeps the weights of the lowest-val_loss epoch on disk;
    # EarlyStopping below does not restore them itself.
    ckpt_path = f"best_candidate_{rank}.weights.h5"
    ckpt = ModelCheckpoint(
        ckpt_path,
        monitor="val_loss",
        save_best_only=True,
        save_weights_only=True,
        mode="min",
        verbose=1
    )

    es = EarlyStopping(
        monitor="val_loss",
        patience=7,
        min_delta=1e-4,
        restore_best_weights=False,
        mode="min",
        verbose=1
    )

    rl = ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.5,
        patience=2,
        min_lr=1e-6,
        mode="min",
        verbose=1
    )

    history_k = model_k.fit(
        [Xspec_train, Xfeat_train], y_train,
        validation_data=([Xspec_val, Xfeat_val], y_val),
        epochs=100,
        batch_size=bs,
        callbacks=[es, rl, ckpt],
        class_weight=class_weight,
        verbose=1
    )

    # Load best weights by val_loss
    model_k.load_weights(ckpt_path)

    # Validation metrics drive selection; test metrics are informational only.
    val_loss_k, val_acc_k = model_k.evaluate([Xspec_val, Xfeat_val], y_val, verbose=0)
    test_loss_k, test_acc_k = model_k.evaluate([Xspec_test, Xfeat_test], y_test, verbose=0)
    print(f"[CANDIDATE #{rank}] VAL  acc={val_acc_k:.4f} | loss={val_loss_k:.4f}")
    print(f"[CANDIDATE #{rank}] TEST acc={test_acc_k:.4f} | loss={test_loss_k:.4f}")

    results.append({
        "rank_from_tuner": rank,
        "val_acc": float(val_acc_k),
        "val_loss": float(val_loss_k),
        "test_acc": float(test_acc_k),
        "test_loss": float(test_loss_k),
        "hp": hp,
        "ckpt_path": ckpt_path,
    })

# Ranking on VALIDATION metrics only (ties keep the tuner's order).
best_by_loss = min(results, key=lambda d: d["val_loss"])
best_by_acc  = max(results, key=lambda d: d["val_acc"])

print("\n" + "="*70)
print("BEST by VAL LOSS (min):")
print(f"  candidate_from_tuner: #{best_by_loss['rank_from_tuner']}")
print(f"  val_loss:  {best_by_loss['val_loss']:.4f}")
print(f"  val_acc:   {best_by_loss['val_acc']:.4f}")
print(f"  test_acc:  {best_by_loss['test_acc']:.4f}")
print(f"  test_loss: {best_by_loss['test_loss']:.4f}")

print("\nBEST by VAL ACC (max):")
print(f"  candidate_from_tuner: #{best_by_acc['rank_from_tuner']}")
print(f"  val_acc:   {best_by_acc['val_acc']:.4f}")
print(f"  val_loss:  {best_by_acc['val_loss']:.4f}")
print(f"  test_acc:  {best_by_acc['test_acc']:.4f}")
print(f"  test_loss: {best_by_acc['test_loss']:.4f}")


chosen = best_by_loss

# Rebuild & load BEST model
tf.keras.backend.clear_session()
tf.keras.utils.set_random_seed(SEED)
np.random.seed(SEED)

best_model = build_model(chosen["hp"])
best_model.load_weights(chosen["ckpt_path"])

print("\nUsing chosen candidate:", chosen["rank_from_tuner"])
print("Checkpoint:", chosen["ckpt_path"])


# =========================
# EVALUATION (BEST MODEL)
# =========================
# Test-set loss/accuracy, confusion matrix and per-class report for best_model.
test_loss, test_acc = best_model.evaluate([Xspec_test, Xfeat_test], y_test, verbose=0)
print(f"\n[FINAL] Test accuracy: {test_acc:.4f} | loss: {test_loss:.4f}")

# Softmax outputs: column 0 = P(Low, label 0), column 1 = P(High, label 1).
y_prob = best_model.predict([Xspec_test, Xfeat_test], verbose=0)
y_pred = np.argmax(y_prob, axis=1)

# Rows = true label, columns = predicted label.
# NOTE(review): the tick labels assume both classes occur in y_test/y_pred;
# otherwise the matrix is 1x1 and the labels no longer line up.
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(5,4))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=["Low","High"], yticklabels=["Low","High"])
plt.title("Confusion Matrix - Best HP (Top-K Selection)")
plt.xlabel("Predicted"); plt.ylabel("True")
plt.tight_layout(); plt.show()

print("\nClassification Report:")
print(classification_report(y_test, y_pred, target_names=["Low","High"], digits=4))

# ============================================================
# LIST MISCLASSIFIED TEST FILES
# ============================================================
final_pred = y_pred
final_prob = y_prob
wrong_idx = np.where(final_pred != y_test)[0]

print("\n================ MISCLASSIFIED TEST FILES ================")
print("Total misclassified:", len(wrong_idx))

# One row per misclassified file: [path, true, predicted, P(Low), P(High)].
# NOTE(review): `rows` is not consumed in the visible code.
rows = []
for i in wrong_idx:
    true_label = "High" if int(y_test[i]) == 1 else "Low"
    pred_label = "High" if int(final_pred[i]) == 1 else "Low"
    prob_low  = float(final_prob[i, 0])
    prob_high = float(final_prob[i, 1])
    # p_test is unpacked at module level after loading, so the fallback
    # branch is only reached if this cell runs without it.
    fpath = p_test[i] if "p_test" in globals() else f"(path not available) idx={i}"

    rows.append([fpath, true_label, pred_label, prob_low, prob_high])

    print(f"- {fpath}")
    print(f"  True={true_label} | Pred={pred_label} | P(Low)={prob_low:.3f} P(High)={prob_high:.3f}")

## Leclère Model (2015)

In [None]:
# ============================================================
# FULL CODE
# "Speech intelligibility prediction in reverberation: Towards an integrated model of speech transmission, spatial unmasking, and binaural de-reverberation"
# Leclere et al., 2015
# ============================================================

import os
import glob
import time
import numpy as np
import soundfile as sf

from scipy.signal import butter, sosfiltfilt
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score

# =========================
# CONFIG
# =========================
DATA_ROOT = r"/content/drive/MyDrive/DATASET V3"  # <-- CHANGE to your dataset location
# Class folder name -> label (0 = Low, 1 = High intelligibility).
CLASSES = {"Low Intelligibility": 0, "High Intelligibility": 1}
SPLITS = ["training", "validation", "testing"]

# Leclère "room-independent" example parameters (paper uses ELL=30ms, DD=25ms)
# ELL_MS: length of the flat (weight 1) part of the early window after the
#         direct sound.
# DD_MS:  length of its linear decay to 0 (see make_linear_complementary_windows).
ELL_MS = 30.0
DD_MS  = 25.0

# Six octave bands from 125 Hz to 8 kHz; one (better-ear U/D, L-R U/D
# difference) pair of features is computed per band.
BANDS_HZ = [
    (125, 250),
    (250, 500),
    (500, 1000),
    (1000, 2000),
    (2000, 4000),
    (4000, 8000),
]

# Small constant added to energies and denominators to avoid log(0) / division by zero.
EPS = 1e-12


# =========================
# 1) Direct sound detection (Leclère rule)
# =========================
def find_direct_index_leclere(x: np.ndarray) -> int:
    """Index of the direct-sound onset using Leclère's rule.

    The onset is the first sample whose magnitude is at least 25% greater
    than the magnitude of every previous sample: at least ``1.25`` times the
    running maximum AND strictly above it.

    The strict increase matters when the response starts with digital
    silence. With a zero running maximum, ``0 >= 1.25 * 0`` would otherwise
    report index 1 instead of the first non-zero sample.

    Returns 0 when no sample qualifies or the input has fewer than two
    samples.
    """
    a = np.abs(np.asarray(x, dtype=np.float64))
    if a.size < 2:
        return 0
    # prev_max[j] = max(|x[0]|, ..., |x[j]|): the running maximum before sample j + 1.
    prev_max = np.maximum.accumulate(a)[:-1]
    cand = a[1:]
    hits = np.flatnonzero((cand > prev_max) & (cand >= 1.25 * prev_max))
    return int(hits[0]) + 1 if hits.size else 0


def direct_index_stereo(y: np.ndarray) -> int:
    """Direct-sound onset index for a mono or two-column signal.

    For 1-D input this is the onset of that signal. Otherwise it is the
    earlier of the onsets detected on column 0 and column 1.
    """
    if y.ndim != 1:
        return int(min(find_direct_index_leclere(y[:, ch]) for ch in (0, 1)))
    return find_direct_index_leclere(y)


# =========================
# 2) Early/Late complementary windows (Leclère)
# =========================
def make_linear_complementary_windows(n: int, sr: int, direct_idx: int, ell_ms: float, dd_ms: float):
    """Complementary early/late weighting windows of length *n*.

    The early window:

    * is 1 from sample 0 up to ``direct_idx + ell``;
    * then ramps linearly from 1 towards 0 over the next ``dd`` samples;
    * is 0 afterwards.

    ``ell``/``dd`` are *ell_ms*/*dd_ms* converted to samples, and every
    boundary is clipped to *n*. The late window is ``1 - early``, so the two
    always sum to one.

    Returns ``(w_early, w_late)`` as float64 arrays.
    """
    def to_samples(ms):
        return int(round((ms / 1000.0) * sr))

    ell = to_samples(ell_ms)
    dd = to_samples(dd_ms)
    flat_end = min(n, direct_idx + ell)
    decay_end = min(n, flat_end + dd)

    w_early = np.zeros(n, dtype=np.float64)
    if flat_end > 0:
        w_early[:flat_end] = 1.0

    ramp_len = decay_end - flat_end
    if ramp_len > 0 and dd > 0:
        # endpoint=False: the ramp stops one step short of 0; the following sample is 0.
        w_early[flat_end:decay_end] = np.linspace(1.0, 0.0, ramp_len, endpoint=False)

    return w_early, 1.0 - w_early


# =========================
# 3) Bandpass + energy
# =========================
def bandpass_sos(low_hz: float, high_hz: float, sr: int, order: int = 4):
    """Butterworth band-pass filter in second-order sections, or None.

    The band edges are normalised to Nyquist and clipped to ``[1e-6, 0.999999]``.

    Returns None when the clipped band is empty, e.g. when the whole band
    lies above Nyquist; ``band_energy`` then uses the unfiltered signal.
    """
    nyq = 0.5 * sr
    edges = (max(low_hz / nyq, 1e-6), min(high_hz / nyq, 0.999999))
    if edges[1] <= edges[0]:
        return None
    return butter(order, list(edges), btype="band", output="sos")


def band_energy(x: np.ndarray, sos) -> float:
    """Sum of squares (plus EPS) of *x* after zero-phase band-pass filtering.

    Filtering is forward-backward (``sosfiltfilt``). When *sos* is None the
    signal is used unfiltered.
    """
    sig = x.astype(np.float64)
    if sos is not None:
        sig = sosfiltfilt(sos, sig)
    return float(np.sum(sig * sig) + EPS)


# =========================
# 4) Leclère-style U/D features (better-ear, per band)
# =========================
def leclere_ud_features(wav_path: str, ell_ms: float = ELL_MS, dd_ms: float = DD_MS) -> np.ndarray:
    """Per-band useful-to-detrimental (U/D) energy ratios in the style of Leclère.

    Processing steps:

    1. Read the file with soundfile; mono is duplicated to two channels.
    2. Take the direct-sound onset as the earlier of the two channels' onsets.
    3. Split each channel into early ("useful") and late ("detrimental")
       parts with complementary windows.
    4. For every band in ``BANDS_HZ``, compute
       ``10*log10(E_early / E_late)`` per channel.

    Returns a float32 vector of length ``2 * len(BANDS_HZ)``. For each band it
    holds the better-ear (larger of left/right) U/D in dB, followed by the
    left-minus-right U/D difference in dB.
    """
    y, sr = sf.read(wav_path, always_2d=True)  # (n, ch)
    if y.shape[1] == 1:
        y = np.repeat(y, 2, axis=1)

    n_samples = y.shape[0]
    onset = direct_index_stereo(y)
    w_early, w_late = make_linear_complementary_windows(
        n=n_samples, sr=sr, direct_idx=onset, ell_ms=ell_ms, dd_ms=dd_ms
    )

    # NOTE: only columns 0 and 1 are used; any further channels are ignored.
    early = y * w_early[:, None]
    late = y * w_late[:, None]

    def ud_db(ch, sos):
        return 10.0 * np.log10(band_energy(early[:, ch], sos) / band_energy(late[:, ch], sos))

    feats = []
    for low, high in BANDS_HZ:
        sos = bandpass_sos(low, high, sr)
        ud_left = ud_db(0, sos)
        ud_right = ud_db(1, sos)
        feats += [max(ud_left, ud_right), ud_left - ud_right]

    return np.asarray(feats, dtype=np.float32)


# =========================
# 5) Load dataset
# =========================
def list_wavs(folder):
    """Return sorted, de-duplicated WAV paths directly inside *folder*.

    ``.wav``, ``.WAV`` and ``.Wav`` extensions are matched. On case-insensitive
    filesystems each pattern can match the same file, so results are
    de-duplicated before sorting. Otherwise a file could be featurised (and
    weigh in training/evaluation) up to three times.
    """
    exts = ("*.wav", "*.WAV", "*.Wav")
    return sorted({path for e in exts for path in glob.glob(os.path.join(folder, e))})


def collect_split(split: str, measure_time: bool = False):
    """Extract Leclère U/D features for every WAV file of one split.

    Files are read from ``<DATA_ROOT>/<class>/<split>/`` for each class in
    ``CLASSES`` (in dict order). An exception from any file propagates; there
    is no skipping.

    Returns
    -------
    X : float32 array, shape (n_files, 2 * len(BANDS_HZ))
    y : int64 label array
    paths : list of file paths aligned with X and y
    timing : dict or None
        Summary of per-file feature-extraction wall time (keys n_files,
        total_s, mean_ms, p50_ms, p95_ms) when ``measure_time``; else None.
    """
    feats, labels, paths = [], [], []
    durations = []

    for cname, label in CLASSES.items():
        for fp in list_wavs(os.path.join(DATA_ROOT, cname, split)):
            if measure_time:
                start = time.perf_counter()
                feat = leclere_ud_features(fp)
                durations.append(time.perf_counter() - start)
            else:
                feat = leclere_ud_features(fp)
            feats.append(feat)
            labels.append(label)
            paths.append(fp)

    if feats:
        X = np.vstack(feats)
    else:
        X = np.zeros((0, len(BANDS_HZ) * 2), dtype=np.float32)
    y = np.asarray(labels, dtype=np.int64)

    if not measure_time:
        return X, y, paths, None

    ft = np.asarray(durations, dtype=np.float64)
    have_samples = ft.size > 0
    timing = {
        "n_files": int(len(ft)),
        "total_s": float(ft.sum()),
        "mean_ms": float(ft.mean() * 1000.0) if have_samples else 0.0,
        "p50_ms": float(np.percentile(ft, 50) * 1000.0) if have_samples else 0.0,
        "p95_ms": float(np.percentile(ft, 95) * 1000.0) if have_samples else 0.0,
    }
    return X, y, paths, timing


def _print_timing_block(title: str, timing: dict):
    if timing is None:
        return
    n = timing["n_files"]
    total_s = timing["total_s"]
    mean_ms = timing["mean_ms"]
    p50_ms = timing["p50_ms"]
    p95_ms = timing["p95_ms"]
    thr = (n / total_s) if total_s > 0 else 0.0

    print(f"\n=== {title} ===")
    print(f"Files: {n}")
    print(f"Total time: {total_s:.3f} s")
    print(f"Mean/file: {mean_ms:.3f} ms | P50: {p50_ms:.3f} ms | P95: {p95_ms:.3f} ms")
    print(f"Throughput: {thr:.2f} files/s")


# =========================
# 6) Train + validate threshold + test (+ running time)
# =========================
def train_and_evaluate(measure_runtime: bool = True):
    """Train the Leclère U/D logistic-regression baseline and report test results.

    Steps:

    1. Extract features for the training, validation and testing splits.
       Only the test extraction is timed, and only when ``measure_runtime``.
    2. Fit ``StandardScaler`` + class-balanced ``LogisticRegression`` on the
       training split.
    3. Choose the decision threshold on P(High). The search covers
       [0.1, 0.9] in steps of 0.01 and maximises F1 of class 1 (High) on the
       validation split. The threshold stays at 0.5 when validation is empty
       or single-class.
    4. When ``measure_runtime``, time single-sample inference for every test
       file.
    5. Print the test report and confusion matrix, ROC-AUC (if both classes
       are present), the timing summaries and up to 50 misclassified files.

    Returns
    -------
    The fitted sklearn ``Pipeline``.

    Raises
    ------
    RuntimeError
        If the training split is empty.
    """
    # Training/validation extraction (not timed)
    Xtr, ytr, _, _ = collect_split("training", measure_time=False)
    Xva, yva, _, _ = collect_split("validation", measure_time=False)

    # Testing extraction (timed per file when measure_runtime)
    Xte, yte, pte, te_feat_timing = collect_split("testing", measure_time=measure_runtime)

    print("\n=== SPLIT COUNTS (TEST) ===")
    print("Xte shape:", Xte.shape)
    print("yte unique:", np.unique(yte, return_counts=True))

    if Xtr.shape[0] == 0:
        raise RuntimeError("Training set empty. Check the file/folder")

    # Numeric model: standardised features -> class-balanced logistic regression
    clf = Pipeline([
        ("scaler", StandardScaler()),
        ("lr", LogisticRegression(max_iter=2000, class_weight="balanced"))
    ])

    clf.fit(Xtr, ytr)

    # --- validation: search the P(High) threshold that maximises F1 of class 1
    va_prob = clf.predict_proba(Xva)[:, 1] if Xva.shape[0] else None
    best_thr = 0.5
    if va_prob is not None and len(np.unique(yva)) > 1:
        thrs = np.linspace(0.1, 0.9, 81)
        best_f1, best_thr = -1, 0.5
        for t in thrs:
            pred = (va_prob >= t).astype(int)
            tp = np.sum((pred == 1) & (yva == 1))
            fp = np.sum((pred == 1) & (yva == 0))
            fn = np.sum((pred == 0) & (yva == 1))
            prec = tp / (tp + fp + EPS)
            rec  = tp / (tp + fn + EPS)
            f1   = 2 * prec * rec / (prec + rec + EPS)
            if f1 > best_f1:  # strict '>' keeps the lowest threshold on ties
                best_f1, best_thr = f1, t

    # =========================
    # RUNNING TIME (INFERENCE) on TEST SET
    # =========================
    te_infer_timing = None
    te_total_timing = None

    if measure_runtime and Xte.shape[0] > 0:
        # Untimed warm-up call before measuring.
        _ = clf.predict_proba(Xte[:min(5, len(Xte))])[:, 1]

        infer_times = []
        for i in range(Xte.shape[0]):
            xi = Xte[i:i+1]
            t0 = time.perf_counter()
            _ = clf.predict_proba(xi)[:, 1]
            t1 = time.perf_counter()
            infer_times.append(t1 - t0)

        it = np.asarray(infer_times, dtype=np.float64)
        te_infer_timing = {
            "n_files": int(len(it)),
            "total_s": float(it.sum()),
            "mean_ms": float(it.mean() * 1000.0),
            "p50_ms": float(np.percentile(it, 50) * 1000.0),
            "p95_ms": float(np.percentile(it, 95) * 1000.0),
        }

        # End-to-end = feature extraction + inference (per test split).
        # NOTE: total/mean are exact. P50/P95 are the SUM of the two stages'
        # percentiles, which only approximates the percentile of per-file
        # end-to-end time (collect_split returns only a timing summary).
        feat_total_s = te_feat_timing["total_s"] if te_feat_timing else 0.0
        feat_p50_ms = te_feat_timing["p50_ms"] if te_feat_timing else 0.0
        feat_p95_ms = te_feat_timing["p95_ms"] if te_feat_timing else 0.0
        te_total_timing = {
            "n_files": int(Xte.shape[0]),
            "total_s": float(feat_total_s + it.sum()),
            "mean_ms": float((feat_total_s + it.sum()) / Xte.shape[0] * 1000.0),
            "p50_ms": float(feat_p50_ms + np.percentile(it, 50) * 1000.0),
            "p95_ms": float(feat_p95_ms + np.percentile(it, 95) * 1000.0),
        }

    # =========================
    # TEST (predictions)
    # =========================
    te_prob = clf.predict_proba(Xte)[:, 1]
    te_pred = (te_prob >= best_thr).astype(int)

    print("\nBest threshold from validation:", best_thr)
    print("\n=== TEST REPORT ===")
    print(classification_report(
        yte, te_pred,
        labels=[0, 1],
        target_names=["Low", "High"],
        zero_division=0
    ))
    print("Confusion matrix:\n", confusion_matrix(yte, te_pred, labels=[0, 1]))
    if len(np.unique(yte)) > 1:
        print("ROC-AUC:", roc_auc_score(yte, te_prob))

    # =========================
    # PRINT TIMING SUMMARY
    # =========================
    if measure_runtime:
        _print_timing_block("TEST FEATURE EXTRACTION TIME", te_feat_timing)
        _print_timing_block("TEST INFERENCE TIME (predict_proba)", te_infer_timing)
        _print_timing_block("TEST END-TO-END TIME (feature + inference)", te_total_timing)

    # List misclassified test files (first 50)
    wrong = np.where(te_pred != yte)[0]
    print("\n=== MISCLASSIFIED TEST FILES ===")
    print("Total misclassified:", len(wrong))
    for i in wrong[:50]:
        print(f"- {pte[i]} | true={yte[i]} pred={te_pred[i]} pHigh={te_prob[i]:.3f}")

    return clf


# Entry point: train and evaluate the Leclère baseline with runtime
# measurement enabled; the fitted pipeline is kept in `model`.
if __name__ == "__main__":
    model = train_and_evaluate(measure_runtime=True)


## Cardiff Model (2013)

In [None]:
# ============================================================
# FULL CODE
# "Predicting binaural speech intelligibility in architectural acoustics"
# Culling et al., 2013
# ============================================================

import os, glob, time
import numpy as np
import soundfile as sf

import scipy.signal as sps
from scipy.signal import gammatone, lfilter, fftconvolve

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score

from sklearn.impute import SimpleImputer


# =========================
# CONFIG
# =========================
# Dataset root; collect_split reads DATA_ROOT/<class folder>/<split>/*.wav
DATA_ROOT = r"/content/drive/MyDrive/DATASET V3"
# Class folder name -> integer label (1 = "High" in the reports)
CLASSES = {"Low Intelligibility": 0, "High Intelligibility": 1}

# Cardiff method settings
# Gammatone filterbank: centre frequencies from F_MIN to F_MAX (Hz),
# spaced ERB_STEP apart on the ERB-number scale.
F_MIN = 20.0
F_MAX = 10000.0
ERB_STEP = 0.5

# Length (ms) of each interaural cross-correlation window; half of it is also
# the time constant of the exponential taper applied to the window.
XCORR_WIN_MS = 100.0
# Maximum interaural lag (ms, in either direction) searched for the coherence peak.
LAG_MS = 5.0
# Start times (s) of the analysis windows; per-band BMLD/coherence are averaged over them.
EPOCHS_S = [0.5, 1.0, 1.5, 2.0]
# Duration (s) of the noise signal convolved with each BRIR.
NOISE_DUR_S = 4.3

# Eq constants
# Time-jitter constant (s) used in the BMLD k term exp(omega^2 * SIGMA_D^2);
# presumably the 105 us sigma_delta of the Culling/Lavandier EC model — confirm.
SIGMA_D = 0.000105
# Small constant guarding divisions and logs against zero.
EPS = 1e-12

# Optional path to a noise WAV; None -> get_noise generates seeded white Gaussian noise.
SPEECH_SHAPED_NOISE_WAV = None

# Safety limits
MAX_ABS_AFTER_CONV = 0.99   # normalize after convolution
COH_CLIP = 0.999            # clamp coherence
RMS_FLOOR = 1e-9            # floor RMS to avoid log issues


# =========================
# ERB helpers
# =========================
def hz_to_erbnum(f_hz: float) -> float:
    """Map a frequency in Hz onto the ERB-number scale.

    Uses ERB-number = 21.4 * log10(4.37e-3 * f + 1). Works element-wise on
    NumPy arrays as well as on scalars.
    """
    scaled = 4.37e-3 * f_hz + 1.0
    return 21.4 * np.log10(scaled)

def erbnum_to_hz(erb: float) -> float:
    """Inverse of ``hz_to_erbnum``: convert an ERB-number back to Hz.

    Computes (10 ** (erb / 21.4) - 1) / 4.37e-3; element-wise on arrays.
    """
    linear = 10 ** (erb / 21.4)
    return (linear - 1.0) / 4.37e-3

def erb_center_frequencies(fmin=F_MIN, fmax=F_MAX, step_erb=ERB_STEP) -> np.ndarray:
    """Return filterbank centre frequencies (Hz) equally spaced in ERB-number.

    The grid starts at ``fmin`` and advances in ``step_erb`` increments up to
    ``fmax``. A 1e-9 slack keeps ``fmax`` itself when it lands on the grid.
    """
    erb_lo = hz_to_erbnum(fmin)
    erb_hi = hz_to_erbnum(fmax)
    erb_grid = np.arange(erb_lo, erb_hi + 1e-9, step_erb)
    centres_hz = erbnum_to_hz(erb_grid)
    return centres_hz.astype(np.float64)


# =========================
# SII weighting proxy
# =========================
def sii_weight_proxy(fc_hz: np.ndarray) -> np.ndarray:
    """Crude band-importance weights standing in for SII weighting.

    Shape of the weight curve:
    - weight 1 for 400-4400 Hz;
    - rises log-linearly from 0 at 100 Hz to 1 at 400 Hz;
    - falls log-linearly from 1 at 4400 Hz to 0 at 8000 Hz;
    - 0 elsewhere.

    The weights are clipped to [0, 1] and then normalised to sum to 1 when
    the sum is positive.
    """
    freqs = np.asarray(fc_hz, dtype=np.float64)
    weights = np.zeros_like(freqs)

    band_lo, band_hi = 400.0, 4400.0
    floor_hz, ceil_hz = 100.0, 8000.0

    # flat plateau
    weights[(freqs >= band_lo) & (freqs <= band_hi)] = 1.0

    # log-frequency ramp up below the plateau
    rising = (freqs >= floor_hz) & (freqs < band_lo)
    if rising.any():
        weights[rising] = (np.log(freqs[rising]) - np.log(floor_hz)) / (
            np.log(band_lo) - np.log(floor_hz) + EPS
        )

    # log-frequency ramp down above the plateau
    falling = (freqs > band_hi) & (freqs <= ceil_hz)
    if falling.any():
        weights[falling] = (np.log(ceil_hz) - np.log(freqs[falling])) / (
            np.log(ceil_hz) - np.log(band_hi) + EPS
        )

    weights = np.clip(weights, 0.0, 1.0)
    total = weights.sum()
    if total > 0:
        weights /= total
    return weights


# =========================
# Noise generator / loader
# =========================
def get_noise(sr: int, dur_s: float, path: str | None = SPEECH_SHAPED_NOISE_WAV) -> np.ndarray:
    """Return a unit-RMS, mono noise signal of ``round(dur_s * sr)`` samples.

    Args:
        sr: target sample rate (Hz).
        dur_s: duration in seconds.
        path: optional noise WAV file. When None, white Gaussian noise from a
            fixed-seed generator is returned, so every call is identical.

    When a file is given:
    - multi-channel content is down-mixed to mono by averaging channels;
    - it is resampled to ``sr`` if needed;
    - it is looped (tiled) when shorter than the requested length, then
      truncated to that length.

    Raises:
        ValueError: if the noise file contains no samples.
    """
    n = int(round(dur_s * sr))
    if path is None:
        rng = np.random.default_rng(0)
        x = rng.standard_normal(n).astype(np.float64)
        # normalize noise RMS
        return x / (np.sqrt(np.mean(x * x) + EPS) + EPS)

    x, sr_file = sf.read(path, always_2d=False)
    x = np.asarray(x, dtype=np.float64)
    if x.ndim > 1:
        # Callers convolve this with a 1-D BRIR channel, and np.tile below
        # would repeat along the channel axis of a 2-D array; down-mix first.
        x = x.mean(axis=1)
    if x.size == 0:
        raise ValueError(f"Noise file contains no samples: {path}")

    if sr_file != sr:
        x = sps.resample_poly(x, up=sr, down=sr_file).astype(np.float64)

    if len(x) < n:
        x = np.tile(x, int(np.ceil(n / len(x))))
    x = x[:n]
    return x / (np.sqrt(np.mean(x * x) + EPS) + EPS)


# =========================
# Numeric-safe helpers
# =========================
def safe_rms_db(x: np.ndarray) -> float:
    """Return the RMS level of ``x`` in dB (re 1.0), guarded against log(0).

    EPS is added to the mean square, and the RMS is floored at RMS_FLOOR
    before the logarithm.
    """
    samples = x.astype(np.float64)
    mean_square = np.mean(samples * samples)
    rms = max(np.sqrt(mean_square + EPS), RMS_FLOOR)
    return float(20.0 * np.log10(rms + EPS))

def exp_taper_window(n: int, sr: int, win_ms: float = XCORR_WIN_MS) -> np.ndarray:
    """Return an ``n``-sample decaying exponential window ``exp(-t / tau)``.

    ``t`` is the sample time in seconds at rate ``sr``. The time constant
    ``tau`` is half of ``win_ms`` (converted to seconds).
    """
    time_constant_s = 0.5 * (win_ms / 1000.0)
    sample_times_s = np.arange(n, dtype=np.float64) / sr
    return np.exp(-sample_times_s / (time_constant_s + EPS))

def xcorr_max_coherence_delay(xL: np.ndarray, xR: np.ndarray, sr: int, lag_ms: float = LAG_MS):
    """Find the interaural lag with the largest normalised cross-correlation.

    Lags from -lag_ms to +lag_ms (rounded to whole samples) are scanned. At a
    positive lag L, xL[n] is paired with xR[n + L]; a positive result
    therefore means the right channel lags the left. Every lag is normalised
    by the product of the full-length channel norms.

    Returns:
        (coherence, delay_s): peak correlation clipped to
        [-COH_CLIP, COH_CLIP], and the corresponding lag in seconds.
    """
    left = xL.astype(np.float64)
    right = xR.astype(np.float64)
    max_shift = int(round((lag_ms / 1000.0) * sr))

    # Norms are taken separately so the normaliser never squares a huge sum.
    norm_left = np.sqrt(np.sum(left * left) + EPS)
    norm_right = np.sqrt(np.sum(right * right) + EPS)
    scale = (norm_left * norm_right) + EPS

    peak_value, peak_shift = -1e9, 0
    for shift in range(-max_shift, max_shift + 1):
        if shift > 0:
            seg_left, seg_right = left[:-shift], right[shift:]
        elif shift < 0:
            seg_left = left[-shift:]
            seg_right = right[:len(seg_left)]
        else:
            seg_left, seg_right = left, right

        value = float(np.sum(seg_left * seg_right) / scale)
        if value > peak_value:
            peak_value, peak_shift = value, shift

    # rounding can push the peak marginally past +/-1
    coherence = float(np.clip(peak_value, -COH_CLIP, COH_CLIP))
    return coherence, float(peak_shift) / sr

def bmld_from_params(rho: float, T_rad: float, I_rad: float, fc_hz: float) -> float:
    omega = 2.0 * np.pi * fc_hz
    k = (1.0 + rho)**2 * np.exp((omega**2) * (SIGMA_D**2))

    # num/den must be positive
    num = (k - np.cos(T_rad - I_rad))
    den = (k - rho)

    ratio = (num + EPS) / (den + EPS)
    # clamp to avoid log10 of <=0 due to numeric issues
    ratio = float(max(ratio, 1e-12))

    val = 10.0 * np.log10(ratio)
    if not np.isfinite(val) or val < 0:
        val = 0.0
    return float(val)


# =========================
# Cardiff feature extractor (stable)
# =========================
def _band_binaural_stats(fL: np.ndarray, fR: np.ndarray, f0: float, sr: int,
                         win_n: int, taper: np.ndarray) -> tuple[float, float]:
    """Average (BMLD dB, interaural coherence) of one filtered band over EPOCHS_S.

    Each epoch is a tapered ``win_n``-sample window starting at the epoch
    time. Epochs that would run past the end of the signal are skipped, and
    (0.0, 0.0) is returned if none fit.
    """
    bmlds, cohs = [], []
    for t0 in EPOCHS_S:
        s0 = int(round(t0 * sr))
        s1 = s0 + win_n
        if s1 > len(fL):
            continue

        segL = fL[s0:s1] * taper
        segR = fR[s0:s1] * taper
        rho, delay_s = xcorr_max_coherence_delay(segL, segR, sr, LAG_MS)

        # interferer IPD from the best lag, kept within [-pi, pi]
        I_rad = float(np.clip(2.0 * np.pi * f0 * delay_s, -np.pi, np.pi))
        # target assumed diotic (IPD = 0)
        bmlds.append(bmld_from_params(rho=rho, T_rad=0.0, I_rad=I_rad, fc_hz=f0))
        cohs.append(rho)

    bmld = float(np.mean(bmlds)) if bmlds else 0.0
    coh = float(np.mean(cohs)) if cohs else 0.0
    return bmld, coh


def cardiff_features_from_brir(wav_path: str) -> np.ndarray:
    """Extract five Cardiff-style intelligibility features from one BRIR WAV.

    Processing steps:
    1. Noise from ``get_noise`` is convolved with the left and right BRIR
       channels. The BRIR acts as the interferer's path; the target is
       assumed diotic.
    2. The two ear signals are split by an ERB-spaced gammatone filterbank.
    3. For each band, the masker level at each ear and the interaural
       coherence/BMLD are measured.

    Args:
        wav_path: WAV file. A mono file is duplicated to both ears; only the
            first two channels of a multi-channel file are used.

    Returns:
        float32 array ``[effective, binaural_adv, monaural_term,
        mean_coh_low, mean_bmld_low]``. Any non-finite entry is replaced by 0.
    """
    brir, sr = sf.read(wav_path, always_2d=True)
    brir = brir.astype(np.float64)
    if brir.shape[1] == 1:
        brir = np.repeat(brir, 2, axis=1)
    brir = brir[:, :2]

    # Scale both ears by ONE common peak to avoid a huge convolution gain.
    # Normalising each channel separately would equalise the ears and erase
    # the interaural level difference that the better-ear term relies on.
    brir = brir / (np.max(np.abs(brir)) + EPS)

    noise = get_noise(sr, NOISE_DUR_S, SPEECH_SHAPED_NOISE_WAV)
    yL = fftconvolve(noise, brir[:, 0], mode="full")[:len(noise)]
    yR = fftconvolve(noise, brir[:, 1], mode="full")[:len(noise)]

    # joint normalisation after convolution (keeps the L/R ratio)
    mx = max(np.max(np.abs(yL)), np.max(np.abs(yR))) + EPS
    yL = (yL / mx) * MAX_ABS_AFTER_CONV
    yR = (yR / mx) * MAX_ABS_AFTER_CONV

    fc = erb_center_frequencies(F_MIN, F_MAX, ERB_STEP)
    # scipy's gammatone requires 0 < f0 < fs/2; drop bands at or above Nyquist
    # (only matters when the BRIR is sampled below 2 * F_MAX).
    fc = fc[fc < 0.5 * sr]
    w_sii = sii_weight_proxy(fc)

    win_n = int(round((XCORR_WIN_MS / 1000.0) * sr))
    taper = exp_taper_window(win_n, sr, XCORR_WIN_MS)

    n_bands = len(fc)
    bmld_band = np.zeros(n_bands, dtype=np.float64)
    coh_band = np.zeros(n_bands, dtype=np.float64)
    int_level_L = np.zeros(n_bands, dtype=np.float64)
    int_level_R = np.zeros(n_bands, dtype=np.float64)

    for i, f0 in enumerate(fc):
        b, a = gammatone(f0, "iir", fs=sr)
        fL = lfilter(b, a, yL)
        fR = lfilter(b, a, yR)

        if not (np.all(np.isfinite(fL)) and np.all(np.isfinite(fR))):
            # unstable filter output: treat band as silent, no binaural benefit
            int_level_L[i] = -120.0
            int_level_R[i] = -120.0
            continue

        # monaural masker levels (RMS dB)
        int_level_L[i] = safe_rms_db(fL)
        int_level_R[i] = safe_rms_db(fR)
        bmld_band[i], coh_band[i] = _band_binaural_stats(fL, fR, f0, sr, win_n, taper)

    binaural_adv = float(np.sum(w_sii * bmld_band))

    # better-ear masker level (lower masker => better)
    int_level_be = np.minimum(int_level_L, int_level_R)
    # monaural term: higher is better => minus masker level
    monaural_term = float(np.sum(w_sii * (-int_level_be)))

    effective = monaural_term + binaural_adv

    low = fc <= 1500.0
    mean_coh_low = float(np.mean(coh_band[low])) if np.any(low) else float(np.mean(coh_band))
    mean_bmld_low = float(np.mean(bmld_band[low])) if np.any(low) else float(np.mean(bmld_band))

    feat = np.asarray([effective, binaural_adv, monaural_term, mean_coh_low, mean_bmld_low],
                      dtype=np.float32)
    # final safety: replace any non-finite with 0
    feat[~np.isfinite(feat)] = 0.0
    return feat


# =========================
# Dataset loading + timing
# =========================
def list_wavs(folder):
    """Return the sorted paths of all ``.wav`` files (any letter case) in ``folder``.

    Matching uses the lower-cased extension rather than several case-variant
    globs. On platforms where glob matching ignores case (Windows), the old
    "*.wav" / "*.WAV" / "*.Wav" patterns each matched every file, listing it
    three times. Variants such as ".wAv" are no longer missed.

    As with glob's ``*``, hidden names starting with "." (e.g. macOS "._x.wav"
    sidecar files) are skipped. Sub-directories are ignored, and a missing
    folder yields an empty list.
    """
    # escape the folder so brackets etc. in the path are not treated as wildcards
    pattern = os.path.join(glob.escape(folder), "*")
    return sorted(
        path for path in glob.glob(pattern)
        if os.path.splitext(path)[1].lower() == ".wav" and os.path.isfile(path)
    )

def collect_split(split: str, measure_time: bool = False):
    """Extract Cardiff features for every WAV of one dataset split.

    Files are read from ``DATA_ROOT/<class folder>/<split>``, with the classes
    taken in ``CLASSES`` order.

    Args:
        split: split folder name, e.g. "training" / "validation" / "testing".
        measure_time: if True, time each feature extraction with
            ``time.perf_counter``.

    Returns:
        (X, y, paths, timing):
        - X is the stacked float32 feature matrix, or an empty (0, 5) array
          when no files are found;
        - y is the int64 label array;
        - paths is the list of file paths;
        - timing is None unless ``measure_time``, in which case it is a dict
          with n_files / total_s / mean_ms / p50_ms / p95_ms.
    """
    features, labels, file_paths = [], [], []
    durations = []

    for class_dir, class_label in CLASSES.items():
        for wav_path in list_wavs(os.path.join(DATA_ROOT, class_dir, split)):
            if measure_time:
                start = time.perf_counter()
                vec = cardiff_features_from_brir(wav_path)
                durations.append(time.perf_counter() - start)
            else:
                vec = cardiff_features_from_brir(wav_path)

            features.append(vec)
            labels.append(class_label)
            file_paths.append(wav_path)

    X = np.vstack(features) if features else np.zeros((0, 5), dtype=np.float32)
    y = np.asarray(labels, dtype=np.int64)

    if not measure_time:
        return X, y, file_paths, None

    secs = np.asarray(durations, dtype=np.float64)
    has_data = secs.size > 0
    timing = {
        "n_files": int(len(secs)),
        "total_s": float(secs.sum()),
        "mean_ms": float(secs.mean() * 1000.0) if has_data else 0.0,
        "p50_ms": float(np.percentile(secs, 50) * 1000.0) if has_data else 0.0,
        "p95_ms": float(np.percentile(secs, 95) * 1000.0) if has_data else 0.0,
    }
    return X, y, file_paths, timing

def _print_timing_block(title: str, timing: dict):
    """Print a titled timing summary.

    ``timing`` is a dict with keys n_files, total_s, mean_ms, p50_ms and
    p95_ms. Throughput is files per second, or 0 when the total time is not
    positive. Nothing is printed when ``timing`` is None.
    """
    if timing is not None:
        n_files = timing["n_files"]
        total = timing["total_s"]
        mean_ms, p50_ms, p95_ms = timing["mean_ms"], timing["p50_ms"], timing["p95_ms"]
        files_per_s = n_files / total if total > 0 else 0.0

        report = (
            f"\n=== {title} ===",
            f"Files: {n_files}",
            f"Total time: {total:.3f} s",
            f"Mean/file: {mean_ms:.3f} ms | P50: {p50_ms:.3f} ms | P95: {p95_ms:.3f} ms",
            f"Throughput: {files_per_s:.2f} files/s",
        )
        for line in report:
            print(line)


# =========================
# Train + eval
# =========================
def _select_threshold(va_prob, yva):
    """Pick the threshold in linspace(0.1, 0.9, 81) that maximises validation F1 of class 1.

    Returns 0.5 when there is no validation data or it holds a single class.
    On ties the lowest threshold wins.
    """
    if va_prob is None or len(np.unique(yva)) <= 1:
        return 0.5
    best_f1, best_thr = -1, 0.5
    for t in np.linspace(0.1, 0.9, 81):
        pred = (va_prob >= t).astype(int)
        tp = np.sum((pred == 1) & (yva == 1))
        fp = np.sum((pred == 1) & (yva == 0))
        fn = np.sum((pred == 0) & (yva == 1))
        prec = tp / (tp + fp + EPS)
        rec = tp / (tp + fn + EPS)
        f1 = 2 * prec * rec / (prec + rec + EPS)
        if f1 > best_f1:
            best_f1, best_thr = f1, t
    return best_thr


def _timing_stats(times_s: np.ndarray) -> dict:
    """Summarise non-empty per-file durations (s) in the dict format used by _print_timing_block."""
    return {
        "n_files": int(len(times_s)),
        "total_s": float(times_s.sum()),
        "mean_ms": float(times_s.mean() * 1000.0),
        "p50_ms": float(np.percentile(times_s, 50) * 1000.0),
        "p95_ms": float(np.percentile(times_s, 95) * 1000.0),
    }


def train_and_evaluate(measure_runtime: bool = True):
    """Train the Cardiff-feature logistic-regression classifier and report test results.

    Steps:
    1. Features are extracted for the training/validation/testing splits.
    2. A median-impute -> standardise -> balanced LogisticRegression pipeline
       is fitted on the training split.
    3. The decision threshold is tuned on validation F1.
    4. The test report, confusion matrix, ROC-AUC (if both classes are
       present) and misclassified files are printed.
    5. If ``measure_runtime``, feature, per-file inference and end-to-end
       timings are also printed.

    Returns:
        The fitted sklearn Pipeline.

    Raises:
        RuntimeError: if the training or testing split contains no files.
    """
    Xtr, ytr, _, _ = collect_split("training", measure_time=False)
    Xva, yva, _, _ = collect_split("validation", measure_time=False)
    Xte, yte, pte, te_feat_timing = collect_split("testing", measure_time=measure_runtime)

    print("\n=== SPLIT COUNTS (TEST) ===")
    print("Xte shape:", Xte.shape)
    print("yte unique:", np.unique(yte, return_counts=True))

    if Xtr.shape[0] == 0:
        raise RuntimeError("Training set empty. Check the file/folder")
    # Fail before training: predict_proba on 0 rows would crash after fitting anyway.
    if Xte.shape[0] == 0:
        raise RuntimeError("Test set empty. Check the file/folder")

    clf = Pipeline([
        ("imputer", SimpleImputer(strategy="median")),  # safety net
        ("scaler", StandardScaler()),
        ("lr", LogisticRegression(max_iter=3000, class_weight="balanced"))
    ])
    clf.fit(Xtr, ytr)

    va_prob = clf.predict_proba(Xva)[:, 1] if Xva.shape[0] else None
    best_thr = _select_threshold(va_prob, yva)

    # Inference timing (per-file)
    te_infer_timing = None
    te_total_timing = None
    if measure_runtime:
        _ = clf.predict_proba(Xte[:5])[:, 1]  # warmup

        infer_times = []
        for i in range(Xte.shape[0]):
            xi = Xte[i:i + 1]
            t0 = time.perf_counter()
            _ = clf.predict_proba(xi)[:, 1]
            infer_times.append(time.perf_counter() - t0)

        it = np.asarray(infer_times, dtype=np.float64)
        te_infer_timing = _timing_stats(it)

        feat_total = te_feat_timing["total_s"] if te_feat_timing else 0.0
        feat_p50 = te_feat_timing["p50_ms"] if te_feat_timing else 0.0
        feat_p95 = te_feat_timing["p95_ms"] if te_feat_timing else 0.0
        # NOTE: per-file feature times are not kept, so end-to-end P50/P95 are
        # the SUM of per-stage percentiles, an approximation rather than the
        # percentile of per-file totals. Total and mean are exact.
        te_total_timing = {
            "n_files": int(Xte.shape[0]),
            "total_s": float(feat_total + it.sum()),
            "mean_ms": float((feat_total + it.sum()) / Xte.shape[0] * 1000.0),
            "p50_ms": float(feat_p50 + np.percentile(it, 50) * 1000.0),
            "p95_ms": float(feat_p95 + np.percentile(it, 95) * 1000.0),
        }

    te_prob = clf.predict_proba(Xte)[:, 1]
    te_pred = (te_prob >= best_thr).astype(int)

    print("\nBest threshold from validation:", best_thr)
    print("\n=== TEST REPORT ===")
    print(classification_report(
        yte, te_pred,
        labels=[0, 1],
        target_names=["Low", "High"],
        zero_division=0
    ))
    print("Confusion matrix:\n", confusion_matrix(yte, te_pred, labels=[0, 1]))
    if len(np.unique(yte)) > 1:
        print("ROC-AUC:", roc_auc_score(yte, te_prob))

    if measure_runtime:
        _print_timing_block("TEST FEATURE EXTRACTION TIME (Cardiff features, stable)", te_feat_timing)
        _print_timing_block("TEST INFERENCE TIME (predict_proba)", te_infer_timing)
        _print_timing_block("TEST END-TO-END TIME (feature + inference)", te_total_timing)

    wrong = np.where(te_pred != yte)[0]
    print("\n=== MISCLASSIFIED TEST FILES ===")
    print("Total misclassified:", len(wrong))
    for i in wrong[:50]:
        print(f"- {pte[i]} | true={yte[i]} pred={te_pred[i]} pHigh={te_prob[i]:.3f}")

    return clf


if __name__ == "__main__":
    # Entry point: train the Cardiff-feature logistic-regression baseline and
    # print the test report, runtime summary and misclassified files.
    model = train_and_evaluate(measure_runtime=True)
