# SVM-based Speech/Music Classifier + Efficient Implementation (Filtering & Skipping)

**Paper implemented:** *Efficient implementation techniques of an SVM-based speech/music classifier in SMV* (Lim & Chang, 2015)

This notebook implements four main parts:

1) **Baseline**: train and evaluate an **SVM built from scratch in PyTorch (no sklearn)** to separate speech from music on audio frames  
2) **Filtering (Hierarchical)**: a **simple filter** that labels some frames as *music* without running the SVM  
3) **Skipping**: relying on the **correlation between consecutive frames**, SVM evaluation is *skipped* for a number of the following frames  
4) **Combined**: Filtering → Skipping chained together (the same idea as the flowchart in the paper)

> Important note: in the paper, the six features are extracted directly from inside the SMV codec.  
> To keep this notebook independent of SMV, we implement **approximate/equivalent** versions of those features from the raw signal.  
> If you have access to SMV's internal features, just replace `extract_features_for_frames` with the real ones.

---

## Prerequisites
- Speech and music audio files (wav/mp3/…)
- Installed libraries: `numpy, scipy, torch, librosa, soundfile, matplotlib`

Suggested data layout (the code below sets `DATA_DIR = "musan"`; change it to match your root folder):
```
data/
  speech/
    s1.wav
    s2.wav
  music/
    m1.wav
    m2.wav
```

---


In [33]:
# If needed (in your conda env), install these:
!pip -q install numpy scipy torch librosa soundfile matplotlib


In [34]:
import os
import glob
import time
import numpy as np

import librosa
import soundfile as sf

from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional

import torch
from torch.utils.data import DataLoader, TensorDataset

import matplotlib.pyplot as plt



## 1) Main settings (following the paper as closely as possible)
- Sampling rate: **8 kHz**
- Frame length: **20 ms** → 160 samples
- Overlap: set to 50% here (configurable)
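
As a quick check of the arithmetic behind these settings, here is a minimal sketch. The helper name `frame_geometry` is hypothetical and is not used elsewhere in the notebook; it only shows how the sampling rate, frame duration, and overlap translate into samples per frame, hop size, and feature vectors per second.

```python
def frame_geometry(sr: int, frame_ms: float, overlap: float):
    """Return (frame_len, hop_len, frames_per_second) for the given settings.

    `overlap` is the fraction of a frame shared with the next one (0.5 = 50%).
    """
    frame_len = int(sr * frame_ms / 1000)                      # samples per frame
    hop_len = max(1, int(round(frame_len * (1.0 - overlap))))  # samples between frame starts
    frames_per_second = sr / hop_len
    return frame_len, hop_len, frames_per_second


frame_len, hop_len, fps = frame_geometry(sr=8000, frame_ms=20, overlap=0.5)
print(frame_len, hop_len, fps)
```

With 50% overlap the classifier sees twice as many feature vectors per second as it would with non-overlapping 20 ms frames, which matters when comparing SVM call counts later on.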


In [35]:

SR = 8000
FRAME_MS = 20
FRAME_LEN = int(SR * FRAME_MS / 1000)  # 160
HOP_LEN = FRAME_LEN // 2               # 50% overlap

# Labels
LABEL_SPEECH = 1
LABEL_MUSIC = -1

DATA_DIR = "musan"   # root data folder
SPEECH_DIR = os.path.join(DATA_DIR, "speech")
MUSIC_DIR  = os.path.join(DATA_DIR, "music")



## 2) Loading the data
Each file is split into frames, and each frame is then turned into a feature vector.  
To avoid leakage, the train/test split is done **per file** (group split), so frames from one recording never end up on both sides.
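
The idea of a per-file split can be sketched on toy data as follows. This is an illustration only: `split_by_file` and the file names are hypothetical, and the notebook's own split function (defined in the training section) additionally checks that both classes are present on each side.

```python
import numpy as np


def split_by_file(groups, test_fraction=0.25, seed=0):
    """Assign whole files to the test set; return a boolean test mask per frame."""
    groups = np.asarray(groups)
    files = np.unique(groups)
    rng = np.random.default_rng(seed)
    n_test = max(1, int(round(test_fraction * len(files))))
    test_files = rng.permutation(files)[:n_test]
    return np.isin(groups, test_files)


# One entry per frame: the name of the file the frame came from (toy example).
groups = ["a.wav"] * 3 + ["b.wav"] * 2 + ["c.wav"] * 4 + ["d.wav"] * 1
test_mask = split_by_file(groups)

train_files = set(np.asarray(groups)[~test_mask])
test_files = set(np.asarray(groups)[test_mask])
print("files on both sides:", train_files & test_files)  # should be empty
```

Because the features below are smoothed over time, neighbouring frames of one file are strongly correlated; a purely frame-level random split would let near-duplicates of test frames into the training set.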


In [36]:

def list_audio_files(folder: str) -> List[str]:
    exts = ("*.wav", "*.flac", "*.mp3", "*.m4a", "*.ogg")
    files = []
    for e in exts:
        files.extend(glob.glob(os.path.join(folder, "**", e), recursive=True))
    return sorted(files)

speech_files = list_audio_files(SPEECH_DIR)
music_files  = list_audio_files(MUSIC_DIR)

print("Speech files:", len(speech_files))
print("Music files :", len(music_files))
print("Example speech:", speech_files[:2])
print("Example music :", music_files[:2])


Speech files: 426
Music files : 660
Example speech: ['musan/speech/librivox/speech-librivox-0000.wav', 'musan/speech/librivox/speech-librivox-0001.wav']
Example music : ['musan/music/fma-western-art/music-fma-wa-0000.wav', 'musan/music/fma-western-art/music-fma-wa-0001.wav']



## 3) Feature extraction (close to the paper, but from the raw signal)

The paper uses six features (SMV-internal features).  
Here we build approximations of them:

- Energy: RMS
- Reflection coeffs: mean absolute reflection (PARCOR) coefficient from the LPC Levinson-Durbin recursion (approximate)
- Residual energy: LPC residual energy, normalized by the frame energy
- Pitch corr: maximum of the normalized autocorrelation over the pitch range (e.g. 80–400 Hz)
- Periodicity counter: incremented whenever pitch_corr reaches the threshold (reset every 32 frames)
- Music continuity: a simple counter driven by (high energy + lower pitch_corr) as a cue for music

> If you have the real SMV features, just replace this cell.
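
To get a feel for the pitch-correlation feature, the following sketch applies the same normalized-autocorrelation idea to two synthetic 20 ms frames: a 200 Hz tone and white noise. The function name `pitch_corr` and the test signals are assumptions made for this illustration; real speech and music will behave less cleanly.

```python
import numpy as np


def pitch_corr(frame: np.ndarray, sr: int, fmin: float = 80, fmax: float = 400) -> float:
    """Peak of the normalized autocorrelation within the lag range of [fmin, fmax]."""
    frame = frame - frame.mean()
    energy = float(np.dot(frame, frame))
    if energy <= 1e-12:
        return 0.0
    # Keep non-negative lags only and normalize so that lag 0 equals 1.
    ac = np.correlate(frame, frame, mode="full")[len(frame) - 1:] / energy
    lag_min = int(sr / fmax)
    lag_max = min(int(sr / fmin), len(ac) - 1)
    return float(np.clip(ac[lag_min:lag_max + 1].max(), 0.0, 1.0))


sr, n = 8000, 160                      # 20 ms at 8 kHz
t = np.arange(n) / sr
tone = np.sin(2 * np.pi * 200 * t)     # strongly periodic frame
noise = np.random.default_rng(0).standard_normal(n)  # aperiodic frame

tone_corr = pitch_corr(tone, sr)
noise_corr = pitch_corr(noise, sr)
print(f"tone: {tone_corr:.2f}  noise: {noise_corr:.2f}")
```

The periodic frame should score clearly higher than the noise frame, which is the contrast the periodicity counter and the music-continuity heuristic rely on.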


In [37]:

def frame_signal(y: np.ndarray, frame_len: int, hop_len: int) -> np.ndarray:
    # Returns frames: shape (num_frames, frame_len)
    if len(y) < frame_len:
        y = np.pad(y, (0, frame_len - len(y)))
    n_frames = 1 + (len(y) - frame_len) // hop_len
    frames = np.lib.stride_tricks.as_strided(
        y,
        shape=(n_frames, frame_len),
        strides=(y.strides[0] * hop_len, y.strides[0]),
        writeable=False,
    )
    return frames.copy()

def levinson_durbin(r: np.ndarray, order: int):
    # Levinson-Durbin recursion (returns LPC a, reflection k, and error e)
    a = np.zeros(order + 1)
    e = float(r[0])
    k = np.zeros(order)

    a[0] = 1.0
    if e <= 1e-12:
        return a, k, e

    for i in range(1, order + 1):
        acc = 0.0
        for j in range(1, i):
            acc += a[j] * r[i - j]
        ki = (r[i] - acc) / (e + 1e-12)
        k[i - 1] = ki

        a_new = a.copy()
        a_new[i] = ki
        for j in range(1, i):
            a_new[j] = a[j] - ki * a[i - j]
        a = a_new

        e *= (1.0 - ki * ki)
        if e <= 1e-12:
            break

    return a, k, float(e)

def autocorr_norm_peak(frame: np.ndarray, sr: int, fmin=80, fmax=400) -> float:
    frame = frame - np.mean(frame)
    denom = np.dot(frame, frame) + 1e-12
    if denom <= 1e-12:
        return 0.0
    ac = np.correlate(frame, frame, mode="full")
    ac = ac[len(ac)//2:]  # non-negative lags
    ac = ac / denom

    lag_min = int(sr / fmax)
    lag_max = int(sr / fmin)
    lag_max = min(lag_max, len(ac) - 1)
    if lag_max <= lag_min:
        return 0.0

    peak = np.max(ac[lag_min:lag_max+1])
    return float(np.clip(peak, 0.0, 1.0))

@dataclass
class RunningState:
    energy_ma: float = 0.0
    refl_ma: float = 0.0
    resid_ma: float = 0.0
    pitch_ma: float = 0.0

    cpr: int = 0
    cpr_ma: float = 0.0
    cpr_frame_count: int = 0

    cM: float = 0.0
    cM_ma: float = 0.0

def extract_features_for_frames(
    frames: np.ndarray, sr: int,
    ema_alpha: float = 0.9,
    lpc_order: int = 10,
    periodicity_thr: float = 0.65,
    cm_energy_thr: float = 0.02,
    cm_pitch_thr: float = 0.45
) -> np.ndarray:
    # Outputs feature matrix (num_frames, 6): [energy_ma, refl_ma, resid_ma, pitch_ma, cpr_ma, cM_ma]
    st = RunningState()
    feats = []

    for fr in frames:
        rms = float(np.sqrt(np.mean(fr**2) + 1e-12))

        fr_z = fr - np.mean(fr)
        r = np.correlate(fr_z, fr_z, mode="full")[len(fr_z)-1:]
        r = r[:lpc_order+1]

        a, k, e = levinson_durbin(r, lpc_order)
        refl_summary = float(np.mean(np.abs(k))) if len(k) else 0.0
        resid_energy = float(e / (r[0] + 1e-12)) if r[0] > 0 else 0.0

        pitch_corr = autocorr_norm_peak(fr, sr)

        st.energy_ma = ema_alpha * st.energy_ma + (1 - ema_alpha) * rms
        st.refl_ma   = ema_alpha * st.refl_ma   + (1 - ema_alpha) * refl_summary
        st.resid_ma  = ema_alpha * st.resid_ma  + (1 - ema_alpha) * resid_energy
        st.pitch_ma  = ema_alpha * st.pitch_ma  + (1 - ema_alpha) * pitch_corr

        periodicity_flag = 1 if pitch_corr >= periodicity_thr else 0
        st.cpr += periodicity_flag
        st.cpr_frame_count += 1
        if st.cpr_frame_count >= 32:
            st.cpr = 0
            st.cpr_frame_count = 0

        st.cpr_ma = ema_alpha * st.cpr_ma + (1 - ema_alpha) * st.cpr

        # heuristic "music continuity"
        if (rms >= cm_energy_thr) and (pitch_corr <= cm_pitch_thr):
            st.cM = min(st.cM + 5, 400)
        else:
            st.cM = max(st.cM - 2, 0)

        # Smooth with the same EMA coefficient as the other features
        st.cM_ma = ema_alpha * st.cM_ma + (1 - ema_alpha) * st.cM
        feats.append([st.energy_ma, st.refl_ma, st.resid_ma, st.pitch_ma, st.cpr_ma, st.cM_ma])

    return np.asarray(feats, dtype=np.float32)



## 4) Building the frame dataset
- Each file → (frames, features)
- Each frame's label = the label of its file
- `groups` = file name (for the group split)


In [None]:

def load_audio(path: str, sr: int) -> np.ndarray:
    y, _sr = librosa.load(path, sr=sr, mono=True)
    if np.max(np.abs(y)) > 0:
        y = y / np.max(np.abs(y))
    return y.astype(np.float32)

def build_frame_dataset(files: List[str], label: int):
    X_list, y_list, g_list = [], [], []
    for p in files:
        y = load_audio(p, SR)
        frames = frame_signal(y, FRAME_LEN, HOP_LEN)
        Xf = extract_features_for_frames(frames, SR)
        X_list.append(Xf)
        y_list.append(np.full((Xf.shape[0],), label, dtype=np.int32))
        g_list.extend([os.path.basename(p)] * Xf.shape[0])
    if not X_list:
        return np.zeros((0,6), dtype=np.float32), np.zeros((0,), dtype=np.int32), []
    return np.vstack(X_list), np.concatenate(y_list), g_list

X_s, y_s, g_s = build_frame_dataset(speech_files, LABEL_SPEECH)
X_m, y_m, g_m = build_frame_dataset(music_files,  LABEL_MUSIC)

X = np.vstack([X_s, X_m]) if len(X_s) and len(X_m) else (X_s if len(X_s) else X_m)
y = np.concatenate([y_s, y_m]) if len(y_s) and len(y_m) else (y_s if len(y_s) else y_m)
groups = g_s + g_m

print("X shape:", X.shape, "y:", y.shape, "unique labels:", np.unique(y))


## 5) Training the baseline SVM from scratch (PyTorch)

The paper uses an RBF kernel. Here we likewise use a **hinge-loss SVM** that
approximates the RBF kernel with **Random Fourier Features** (no sklearn).
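
As a sanity check of the Random Fourier Features idea, the sketch below compares the inner product of the random features with the exact RBF kernel $\exp(-\gamma \lVert x - y \rVert^2)$ on a few random points. The helper `rff_map` and the toy data are assumptions for illustration; the sampling scheme (weights drawn from $\mathcal{N}(0, 2\gamma)$, phases uniform on $[0, 2\pi)$) is the same one used in the training cell.

```python
import numpy as np


def rff_map(X: np.ndarray, gamma: float, n_rff: int, seed: int = 0) -> np.ndarray:
    """Map X to random Fourier features so that z(x) . z(y) ~ exp(-gamma * ||x - y||^2)."""
    rng = np.random.default_rng(seed)
    W = rng.normal(0.0, np.sqrt(2.0 * gamma), size=(X.shape[1], n_rff))
    b = rng.uniform(0.0, 2.0 * np.pi, size=n_rff)
    return np.sqrt(2.0 / n_rff) * np.cos(X @ W + b)


rng = np.random.default_rng(1)
X = rng.standard_normal((5, 6))        # 5 toy points with 6 features each
gamma = 0.01

Z = rff_map(X, gamma, n_rff=4096)
approx = Z @ Z.T                        # kernel estimate from random features

sq_dist = ((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1)
exact = np.exp(-gamma * sq_dist)        # exact RBF kernel matrix

max_err = float(np.abs(approx - exact).max())
print("max abs error:", max_err)
```

The approximation error shrinks roughly like $1/\sqrt{\texttt{n\_rff}}$, so `n_rff` trades kernel fidelity against the cost of each prediction.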


In [None]:
# -----------------------------
# Utilities (no sklearn)
# -----------------------------
def group_train_test_split(
    X: np.ndarray,
    y: np.ndarray,
    groups: List[str],
    test_size: float = 0.25,
    random_state: int = 42,
    max_tries: int = 200,
):
    g = np.asarray(groups)
    unique_groups = np.unique(g)
    if len(unique_groups) < 2:
        raise RuntimeError("Not enough files for a group split.")

    n_test = int(round(test_size * len(unique_groups)))
    n_test = max(1, min(len(unique_groups) - 1, n_test))

    rng = np.random.default_rng(random_state)

    for _ in range(max_tries):
        perm = rng.permutation(unique_groups)
        test_groups = set(perm[:n_test].tolist())

        test_mask = np.array([gi in test_groups for gi in g], dtype=bool)
        train_idx = np.where(~test_mask)[0]
        test_idx = np.where(test_mask)[0]

        if len(train_idx) == 0 or len(test_idx) == 0:
            continue

        if len(np.unique(y[train_idx])) < 2:
            continue
        if len(np.unique(y[test_idx])) < 2:
            continue

        return train_idx, test_idx

    raise RuntimeError("Could not find a split with both classes on each side. Add more speech/music files.")


def accuracy_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.size == 0:
        return 0.0
    return float(np.mean(y_true == y_pred))


def confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, labels: List[int]):
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    cm = np.zeros((len(labels), len(labels)), dtype=np.int64)
    for i, lt in enumerate(labels):
        for j, lp in enumerate(labels):
            cm[i, j] = int(np.sum((y_true == lt) & (y_pred == lp)))
    return cm


def classification_report(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    target_names: List[str],
    labels: Optional[List[int]] = None,
) -> str:
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    if labels is None:
        labels = sorted(np.unique(np.concatenate([y_true, y_pred])).tolist())

    header = f"{'class':<16}{'precision':>10}{'recall':>10}{'f1-score':>10}{'support':>10}"
    lines = [header, "-" * len(header)]

    macro_p, macro_r, macro_f = 0.0, 0.0, 0.0
    total_support = 0

    for lbl, name in zip(labels, target_names):
        tp = np.sum((y_true == lbl) & (y_pred == lbl))
        fp = np.sum((y_true != lbl) & (y_pred == lbl))
        fn = np.sum((y_true == lbl) & (y_pred != lbl))
        support = int(np.sum(y_true == lbl))

        p = float(tp / (tp + fp + 1e-12))
        r = float(tp / (tp + fn + 1e-12))
        f1 = float(2 * p * r / (p + r + 1e-12))

        lines.append(f"{name:<16}{p:>10.3f}{r:>10.3f}{f1:>10.3f}{support:>10d}")

        macro_p += p
        macro_r += r
        macro_f += f1
        total_support += support

    n_cls = max(len(labels), 1)
    lines.append("-" * len(header))
    lines.append(f"{'macro avg':<16}{macro_p/n_cls:>10.3f}{macro_r/n_cls:>10.3f}{macro_f/n_cls:>10.3f}{total_support:>10d}")
    lines.append(f"{'accuracy':<16}{'':>10}{'':>10}{accuracy_score(y_true, y_pred):>10.3f}{total_support:>10d}")
    return "\n".join(lines)


# -----------------------------
# PyTorch SVM from scratch
# -----------------------------
class StandardScalerScratch:
    def __init__(self, eps: float = 1e-8):
        self.eps = eps
        self.mean_ = None
        self.scale_ = None

    def fit(self, X: np.ndarray):
        X = np.asarray(X, dtype=np.float32)
        self.mean_ = X.mean(axis=0, keepdims=True)
        self.scale_ = X.std(axis=0, keepdims=True)
        self.scale_[self.scale_ < self.eps] = 1.0
        return self

    def transform(self, X: np.ndarray) -> np.ndarray:
        X = np.asarray(X, dtype=np.float32)
        return (X - self.mean_) / self.scale_

    def fit_transform(self, X: np.ndarray) -> np.ndarray:
        return self.fit(X).transform(X)


class TorchRBFSVMScratch:
    def __init__(
        self,
        gamma: float = 0.01,
        C: float = 1.0,
        n_rff: int = 512,
        lr: float = 3e-3,
        epochs: int = 25,
        batch_size: int = 2048,
        random_state: int = 42,
        device: Optional[str] = None,
        verbose: bool = True,
    ):
        self.gamma = gamma
        self.C = C
        self.n_rff = n_rff
        self.lr = lr
        self.epochs = epochs
        self.batch_size = batch_size
        self.random_state = random_state
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.verbose = verbose

        self.rff_w_ = None
        self.rff_b_ = None
        self.rff_scale_ = np.sqrt(2.0 / self.n_rff)
        self.w_ = None
        self.b0_ = 0.0

    def _rff_numpy(self, X: np.ndarray) -> np.ndarray:
        return self.rff_scale_ * np.cos(X @ self.rff_w_ + self.rff_b_)

    def fit(self, X: np.ndarray, y: np.ndarray):
        X = np.asarray(X, dtype=np.float32)
        y = np.asarray(y, dtype=np.float32)

        labels = set(np.unique(y).tolist())
        if not labels.issubset({-1.0, 1.0}):
            raise ValueError("y must contain only -1 and +1.")

        n_samples, n_features = X.shape
        if n_samples == 0:
            raise ValueError("X is empty.")

        rng = np.random.default_rng(self.random_state)
        self.rff_w_ = rng.normal(
            loc=0.0,
            scale=np.sqrt(2.0 * self.gamma),
            size=(n_features, self.n_rff),
        ).astype(np.float32)
        self.rff_b_ = rng.uniform(0.0, 2.0 * np.pi, size=(self.n_rff,)).astype(np.float32)

        x_tensor = torch.from_numpy(X)
        y_tensor = torch.from_numpy(y)
        loader = DataLoader(
            TensorDataset(x_tensor, y_tensor),
            batch_size=min(self.batch_size, n_samples),
            shuffle=True,
        )

        rff_w_t = torch.from_numpy(self.rff_w_).to(self.device)
        rff_b_t = torch.from_numpy(self.rff_b_).to(self.device)

        w = torch.zeros(self.n_rff, dtype=torch.float32, device=self.device, requires_grad=True)
        b0 = torch.zeros(1, dtype=torch.float32, device=self.device, requires_grad=True)

        opt = torch.optim.Adam([w, b0], lr=self.lr)

        for ep in range(self.epochs):
            total_loss = 0.0
            for xb, yb in loader:
                xb = xb.to(self.device)
                yb = yb.to(self.device)

                z = self.rff_scale_ * torch.cos(xb @ rff_w_t + rff_b_t)
                scores = z @ w + b0

                hinge = torch.clamp(1.0 - yb * scores, min=0.0)
                loss = 0.5 * torch.sum(w * w) + self.C * torch.mean(hinge)

                opt.zero_grad()
                loss.backward()
                opt.step()

                total_loss += float(loss.detach().cpu())

            if self.verbose and ((ep + 1) % 5 == 0 or ep == 0 or ep == self.epochs - 1):
                print(f"Epoch {ep+1:02d}/{self.epochs} - loss: {total_loss / len(loader):.5f}")

        self.w_ = w.detach().cpu().numpy().astype(np.float32)
        self.b0_ = float(b0.detach().cpu().item())
        return self

    def decision_function(self, X: np.ndarray) -> np.ndarray:
        X = np.asarray(X, dtype=np.float32)
        z = self._rff_numpy(X)
        return z @ self.w_ + self.b0_

    def predict(self, X: np.ndarray) -> np.ndarray:
        scores = self.decision_function(X)
        return np.where(scores >= 0.0, LABEL_SPEECH, LABEL_MUSIC).astype(np.int32)


class TorchSVMPipeline:
    def __init__(self, **svm_kwargs):
        self.scaler = StandardScalerScratch()
        self.svm = TorchRBFSVMScratch(**svm_kwargs)

    def fit(self, X: np.ndarray, y: np.ndarray):
        Xs = self.scaler.fit_transform(X)
        self.svm.fit(Xs, y)
        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        Xs = self.scaler.transform(X)
        return self.svm.predict(Xs)


# -----------------------------
# Train / Evaluate baseline
# -----------------------------
if len(np.unique(y)) < 2 or len(y) < 10:
    raise RuntimeError(f"Not enough data. Put at least a few speech and music files under '{DATA_DIR}/'.")

train_idx, test_idx = group_train_test_split(X, y, groups=groups, test_size=0.25, random_state=42)

X_train, X_test = X[train_idx], X[test_idx]
y_train, y_test = y[train_idx], y[test_idx]

baseline = TorchSVMPipeline(
    gamma=0.01,
    C=1.0,
    n_rff=512,
    lr=3e-3,
    epochs=25,
    batch_size=2048,
    random_state=42,
    verbose=True,
)

t0 = time.perf_counter()
baseline.fit(X_train, y_train)
t1 = time.perf_counter()

y_pred = baseline.predict(X_test)

print("Train time (s):", round(t1 - t0, 3))
print("Baseline accuracy:", accuracy_score(y_test, y_pred))
print("Confusion matrix [speech(+1), music(-1)]:")
print(confusion_matrix(y_test, y_pred, labels=[LABEL_SPEECH, LABEL_MUSIC]))
print(classification_report(
    y_test,
    y_pred,
    labels=[LABEL_SPEECH, LABEL_MUSIC],
    target_names=["speech(+1)", "music(-1)"]
))



## 6) Filtering Mechanism (Hierarchical)

If the simple condition holds → music(-1), and the SVM is not run.  
Otherwise → SVM.
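
The routing can be illustrated on a toy feature matrix with a stand-in classifier that counts how many rows it is asked to score. Everything here is an assumption for illustration: `CountingClassifier`, the three-column layout, and the threshold values are hypothetical, not the trained SVM or tuned thresholds.

```python
import numpy as np


class CountingClassifier:
    """Stand-in for the SVM: thresholds column 0 and counts the rows it sees."""

    def __init__(self):
        self.rows_seen = 0

    def predict(self, X: np.ndarray) -> np.ndarray:
        self.rows_seen += len(X)
        return np.where(X[:, 0] >= 0.5, 1, -1)


def filter_then_classify(X, clf, thr_cpr, thr_cM):
    # Cheap rule first: frames flagged here are labelled music without the classifier.
    is_music = (X[:, 1] > thr_cpr) | (X[:, 2] > thr_cM)
    y_hat = np.full(len(X), -1, dtype=np.int32)
    rest = np.where(~is_music)[0]
    if len(rest):
        y_hat[rest] = clf.predict(X[rest])   # only the remaining frames reach the classifier
    return y_hat, is_music


# Toy columns: [score feature, periodicity counter, music-continuity counter]
X = np.array([[0.9, 0.1, 10.0],
              [0.2, 0.3, 20.0],
              [0.8, 5.0, 10.0],
              [0.7, 0.2, 300.0]])
clf = CountingClassifier()
y_hat, is_music = filter_then_classify(X, clf, thr_cpr=2.0, thr_cM=150.0)
print(y_hat, "classifier rows:", clf.rows_seen)
```

Note that the last two frames are labelled music even though the stand-in classifier would have called them speech; this is the accuracy risk that the thresholds have to balance against the saved classifier calls.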


In [None]:

@dataclass
class FilterParams:
    thr_cpr: float
    thr_cM: float
    op: str  # "OR" or "AND"

def filter_predict_music_only(X_feat: np.ndarray, params: FilterParams) -> np.ndarray:
    cpr = X_feat[:, 4]
    cM  = X_feat[:, 5]
    cond1 = cpr > params.thr_cpr
    cond2 = cM  > params.thr_cM
    op = params.op.upper()
    if op == "OR":
        return cond1 | cond2
    if op == "AND":
        return cond1 & cond2
    raise ValueError("op must be OR or AND")

def hierarchical_predict(X_feat: np.ndarray, svm_model, params: FilterParams):
    filt_mask = filter_predict_music_only(X_feat, params)
    y_hat = np.zeros((X_feat.shape[0],), dtype=np.int32)
    y_hat[filt_mask] = LABEL_MUSIC

    idx = np.where(~filt_mask)[0]
    if len(idx):
        y_hat[idx] = svm_model.predict(X_feat[idx])

    stats = {
        "filtered_ratio": float(np.mean(filt_mask)),
        "svm_calls_ratio": float(np.mean(~filt_mask)),
    }
    return y_hat, stats

# Starting values (may need tuning for your data)
fparams = FilterParams(thr_cpr=2.0, thr_cM=150.0, op="OR")

y_h_filt, st_filt = hierarchical_predict(X_test, baseline, fparams)
print("Hierarchical accuracy:", accuracy_score(y_test, y_h_filt))
print("Stats:", st_filt)



## 7) Skipping Mechanism (Inter-frame correlation)

A version based on **previous classifications** (simple and close to the paper).  
Skipping only makes sense on a stream, so here it is run **separately for each file**.
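
To see how many classifier calls such a scheme saves, the sketch below replays the control flow on a precomputed list of decisions instead of a real SVM. The function name `skip_schedule` and the all-speech toy stream are assumptions; the agreement test assumes the threshold equals the history length, as in the default parameters of the next cell.

```python
def skip_schedule(decisions, n_prev=4, num_skip=8):
    """Replay the skipping logic on known classifier outputs.

    decisions[t] is what the classifier would return at frame t (+1 or -1).
    Returns (labels, classifier_calls).
    """
    labels, history, calls = [], [], 0
    skip_left, held = 0, None
    for d in decisions:
        if skip_left > 0 and held is not None:
            labels.append(held)          # reuse the held decision, no classifier call
            skip_left -= 1
            continue
        recent = history[-n_prev:]
        if len(recent) == n_prev and len(set(recent)) == 1:
            # The last n_prev classified frames agree: arm skipping for the next frames.
            held, skip_left = recent[0], num_skip
        labels.append(d)                 # the classifier is called for this frame
        calls += 1
        history.append(d)
    return labels, calls


decisions = [1] * 30                     # a steady stretch of speech
labels, calls = skip_schedule(decisions)
print(f"classifier calls: {calls} of {len(decisions)} frames")
```

On a steady stream the call ratio approaches `1 / (num_skip + 1)`; the price is that a change of class inside a skipped stretch is only noticed once the skip counter runs out.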


In [None]:

@dataclass
class SkipParams:
    N_prev: int = 4
    thr_speech: int = 4
    thr_music: int = 4
    num_skip: int = 8

def skipping_predict_for_sequence(X_seq: np.ndarray, svm_model, sp: SkipParams):
    n = X_seq.shape[0]
    y_hat = np.zeros((n,), dtype=np.int32)

    history: List[int] = []
    nskip = 0
    cprev = None

    svm_calls = 0
    skipped = 0

    for t in range(n):
        if nskip > 0 and cprev is not None:
            y_hat[t] = cprev
            nskip -= 1
            skipped += 1
            continue

        if len(history) >= sp.N_prev:
            lastN = history[-sp.N_prev:]
            speech_count = sum(1 for c in lastN if c == LABEL_SPEECH)
            music_count  = sum(1 for c in lastN if c == LABEL_MUSIC)

            if speech_count >= sp.thr_speech:
                cprev = LABEL_SPEECH
                nskip = sp.num_skip
            elif music_count >= sp.thr_music:
                cprev = LABEL_MUSIC
                nskip = sp.num_skip

        y_hat[t] = svm_model.predict(X_seq[t:t+1])[0]
        svm_calls += 1
        history.append(int(y_hat[t]))

    stats = {"skipped_ratio": skipped / n, "svm_calls_ratio": svm_calls / n}
    return y_hat, stats



### Running skipping on the test set (file-by-file)


In [None]:

def predict_on_files_with_skipping(files: List[str], label: int, svm_model, sp: SkipParams):
    y_all, yhat_all, stats_list = [], [], []
    for p in files:
        ysig = load_audio(p, SR)
        frames = frame_signal(ysig, FRAME_LEN, HOP_LEN)
        Xf = extract_features_for_frames(frames, SR)
        y_true = np.full((Xf.shape[0],), label, dtype=np.int32)

        y_hat, st = skipping_predict_for_sequence(Xf, svm_model, sp)

        y_all.append(y_true)
        yhat_all.append(y_hat)
        stats_list.append(st)

    if not y_all:
        return np.array([], dtype=np.int32), np.array([], dtype=np.int32), {}

    y_all = np.concatenate(y_all)
    yhat_all = np.concatenate(yhat_all)
    stats = {
        "avg_skipped_ratio": float(np.mean([s["skipped_ratio"] for s in stats_list])),
        "avg_svm_calls_ratio": float(np.mean([s["svm_calls_ratio"] for s in stats_list])),
    }
    return y_all, yhat_all, stats

sp = SkipParams(N_prev=4, thr_speech=4, thr_music=4, num_skip=8)

test_groups = set([groups[i] for i in test_idx])
speech_test_files = [p for p in speech_files if os.path.basename(p) in test_groups]
music_test_files  = [p for p in music_files  if os.path.basename(p) in test_groups]

y_ts, yhat_ts, st_s = predict_on_files_with_skipping(speech_test_files, LABEL_SPEECH, baseline, sp)
y_tm, yhat_tm, st_m = predict_on_files_with_skipping(music_test_files,  LABEL_MUSIC,  baseline, sp)

y_true_skip = np.concatenate([y_ts, y_tm]) if len(y_ts) and len(y_tm) else (y_ts if len(y_ts) else y_tm)
y_pred_skip = np.concatenate([yhat_ts, yhat_tm]) if len(yhat_ts) and len(yhat_tm) else (yhat_ts if len(yhat_ts) else yhat_tm)

print("Skipping accuracy:", accuracy_score(y_true_skip, y_pred_skip))
print("Skipping stats (speech):", st_s)
print("Skipping stats (music) :", st_m)



## 8) Combined (Filtering → Skipping)

Filtering runs first; if the filter labels the frame as music, the SVM is not run.  
Otherwise, if skipping is active, the frame reuses the held decision and the SVM is not run either. Only the remaining frames reach the SVM.


In [None]:

@dataclass
class CombinedParams:
    filt: FilterParams
    skip: SkipParams

def combined_predict_for_sequence(X_seq: np.ndarray, svm_model, cp: CombinedParams):
    n = X_seq.shape[0]
    y_hat = np.zeros((n,), dtype=np.int32)

    history: List[int] = []
    nskip = 0
    cprev = None

    svm_calls = 0
    filtered = 0
    skipped = 0

    for t in range(n):
        filt_now = filter_predict_music_only(X_seq[t:t+1], cp.filt)[0]
        if filt_now:
            y_hat[t] = LABEL_MUSIC
            filtered += 1
            if nskip > 0:
                if cprev == LABEL_MUSIC:
                    nskip = max(nskip - 1, 0)
                else:
                    nskip = 0
                    cprev = None
            continue

        if nskip > 0 and cprev is not None:
            y_hat[t] = cprev
            nskip -= 1
            skipped += 1
            continue

        if len(history) >= cp.skip.N_prev:
            lastN = history[-cp.skip.N_prev:]
            speech_count = sum(1 for c in lastN if c == LABEL_SPEECH)
            music_count  = sum(1 for c in lastN if c == LABEL_MUSIC)

            if speech_count >= cp.skip.thr_speech:
                cprev = LABEL_SPEECH
                nskip = cp.skip.num_skip
            elif music_count >= cp.skip.thr_music:
                cprev = LABEL_MUSIC
                nskip = cp.skip.num_skip

        y_hat[t] = svm_model.predict(X_seq[t:t+1])[0]
        svm_calls += 1
        history.append(int(y_hat[t]))

    stats = {
        "filtered_ratio": filtered / n,
        "skipped_ratio": skipped / n,
        "svm_calls_ratio": svm_calls / n,
        "total_saved_ratio": (filtered + skipped) / n,
    }
    return y_hat, stats

def predict_on_files_combined(files: List[str], label: int, svm_model, cp: CombinedParams):
    y_all, yhat_all, stats_list = [], [], []
    for p in files:
        ysig = load_audio(p, SR)
        frames = frame_signal(ysig, FRAME_LEN, HOP_LEN)
        Xf = extract_features_for_frames(frames, SR)
        y_true = np.full((Xf.shape[0],), label, dtype=np.int32)

        y_hat, st = combined_predict_for_sequence(Xf, svm_model, cp)
        y_all.append(y_true)
        yhat_all.append(y_hat)
        stats_list.append(st)

    if not y_all:
        return np.array([], dtype=np.int32), np.array([], dtype=np.int32), {}

    y_all = np.concatenate(y_all)
    yhat_all = np.concatenate(yhat_all)
    keys = stats_list[0].keys()
    stats = {k: float(np.mean([s[k] for s in stats_list])) for k in keys}
    return y_all, yhat_all, stats

cp = CombinedParams(
    filt=FilterParams(thr_cpr=2.0, thr_cM=150.0, op="OR"),
    skip=SkipParams(N_prev=4, thr_speech=4, thr_music=4, num_skip=8),
)

y_cs, yhat_cs, st_cs = predict_on_files_combined(speech_test_files, LABEL_SPEECH, baseline, cp)
y_cm, yhat_cm, st_cm = predict_on_files_combined(music_test_files,  LABEL_MUSIC,  baseline, cp)

y_true_c = np.concatenate([y_cs, y_cm]) if len(y_cs) and len(y_cm) else (y_cs if len(y_cs) else y_cm)
y_pred_c = np.concatenate([yhat_cs, yhat_cm]) if len(yhat_cs) and len(yhat_cm) else (yhat_cs if len(yhat_cs) else yhat_cm)

print("Combined accuracy:", accuracy_score(y_true_c, y_pred_c))
print("Combined stats (speech):", st_cs)
print("Combined stats (music) :", st_cm)



## 9) Quick summary


In [None]:

def summarize(name: str, y_true: np.ndarray, y_pred: np.ndarray, extra: Optional[dict]=None):
    print("="*80)
    print(name)
    print("Accuracy:", accuracy_score(y_true, y_pred))
    print("Confusion matrix [speech(+1), music(-1)]:")
    print(confusion_matrix(y_true, y_pred, labels=[LABEL_SPEECH, LABEL_MUSIC]))
    if extra:
        print("Stats:", extra)

summarize("Baseline (frame-level)", y_test, y_pred)
summarize("Filtering-only (hierarchical)", y_test, y_h_filt, st_filt)
summarize("Skipping-only (file-by-file)", y_true_skip, y_pred_skip, {"speech": st_s, "music": st_m})
summarize("Combined (Filtering→Skipping, file-by-file)", y_true_c, y_pred_c, {"speech": st_cs, "music": st_cm})



## 10) (Optional) Tuning the Filtering thresholds
Because the counter scales in this approximate version differ from those in the paper, this simple grid search helps find usable thresholds.


In [None]:

def grid_search_filter(X_eval, y_eval, svm_model,
                       thr_cpr_list, thr_cM_list, ops=("OR","AND"),
                       min_acc: float = 0.98):
    best = None
    for op in ops:
        for a in thr_cpr_list:
            for b in thr_cM_list:
                fp = FilterParams(thr_cpr=a, thr_cM=b, op=op)
                yhat, st = hierarchical_predict(X_eval, svm_model, fp)
                acc = accuracy_score(y_eval, yhat)
                if acc >= min_acc:
                    score = st["filtered_ratio"]
                    if best is None or score > best["filtered_ratio"]:
                        best = {"params": fp, "acc": acc, **st}
    return best

thr_cpr_list = [0.5, 1.0, 2.0, 3.0, 4.0]
thr_cM_list  = [50, 100, 150, 200, 250, 300]

best = grid_search_filter(X_test, y_test, baseline, thr_cpr_list, thr_cM_list, min_acc=0.98)
print(best)



---

## Done ✅

To get closer to the paper:
1) Extract the exact SMV features from the codec and substitute them for `extract_features_for_frames`  
2) Tune the thresholds and skipping parameters on your own dataset  
3) Measure speed by tracking `svm_calls_ratio` together with wall-clock run time  
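
For the run-time part of the last point, one possible approach is a best-of-N wall-clock measurement around the prediction call. The helper `time_predict` and the `dummy_predict` stand-in are hypothetical; in the notebook you would pass something like `baseline.predict` and the test features instead.

```python
import time

import numpy as np


def time_predict(predict_fn, X, repeats: int = 5) -> float:
    """Best-of-N wall-clock time (seconds) for predict_fn(X)."""
    best = float("inf")
    for _ in range(repeats):
        t0 = time.perf_counter()
        predict_fn(X)
        best = min(best, time.perf_counter() - t0)
    return best


def dummy_predict(X: np.ndarray) -> np.ndarray:
    # Stand-in for a real classifier, used only so the sketch runs on its own.
    return np.where(X.sum(axis=1) >= 0.0, 1, -1)


X = np.random.default_rng(0).standard_normal((10_000, 6)).astype(np.float32)
elapsed = time_predict(dummy_predict, X)
print(f"{elapsed * 1e3:.3f} ms for {len(X)} frames")
```

Taking the minimum over several repeats reduces the influence of one-off delays. Keep in mind that the filtering-only path here predicts in one batch while the skipping paths call the model frame by frame, so their timings are not directly comparable unless both are measured per frame.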
