In [1]:
import torch
from df.enhance import init_df

import sys, os
from pathlib import Path

# Locate the repository root: use the current working directory when it
# contains "metrics.py", otherwise fall back to its parent directory.
cwd = Path(os.getcwd()).resolve()
repo_root = cwd if (cwd / "metrics.py").exists() else cwd.parent
# Put the repo root first on sys.path so the project-local `util` module
# resolves before anything else with the same name.
sys.path.insert(0, str(repo_root))
from util import ModelComparator

# The parent of the working directory is also pushed to the front of sys.path
# before importing `models.generator`.
# NOTE(review): when "metrics.py" is absent this is the same directory as
# repo_root, so the entry is inserted twice — confirm both inserts are needed.
sys.path.insert(0, str(Path.cwd().resolve().parent))
from models.generator import LCTGenerator, LCTGeneratorConfig

# Prefer CUDA when available; everything below is loaded onto this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from torchaudio.backend.common import AudioMetaData


In [2]:
# Root of the experiment data tree. The default is the original local path;
# set the LCT_GAN_DATA_ROOT environment variable to run the notebook on
# another machine without editing this cell.
DATA_ROOT = Path(
    os.environ.get("LCT_GAN_DATA_ROOT", "D:/Projects/LCT-GAN/.data"))
NOISE_DIR = DATA_ROOT / "noise"  # input: noise-only *.wav files
CLEAN_DIR = DATA_ROOT / "clean"  # input: clean reference *.wav files
NOISY_DIR = DATA_ROOT / "noisy"  # output: generated mixtures + metrics CSV
ENH_DIR = DATA_ROOT / "20260105"  # output: per-task model results for this run

In [3]:
# Load the TorchScript model onto `device` and switch it to eval mode.
# NOTE(review): both paths below are absolute and machine-specific.
jit = torch.jit.load("D:/Projects/LCT-GAN/.model/FTFNet_scripted.pt",
                     map_location=device).eval()

# Raw checkpoint object for the locally trained generator; its layout is
# resolved later by extract_state_dict().
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source.
ckpt = torch.load("D:/Projects/LCT-GAN/.model/myftf.pt",
                      map_location=device)

def extract_state_dict(obj):
    """Unwrap a checkpoint container and return the nested state dict.

    When ``obj`` is a dict, the wrapper keys ``state_dict``,
    ``model_state_dict``, ``model``, ``generator`` and ``net`` are probed in
    that order and the first one holding a dict is returned. Anything else
    (including a dict without such a key) is returned unchanged, on the
    assumption that it already is the state dict.
    """
    if not isinstance(obj, dict):
        return obj

    wrapper_keys = ("state_dict", "model_state_dict", "model", "generator",
                    "net")
    nested = (obj[key] for key in wrapper_keys
              if isinstance(obj.get(key), dict))
    # Fall back to the object itself when no wrapper key matched.
    return next(nested, obj)

# Resolve the actual state dict from whatever layout the checkpoint uses.
sd = extract_state_dict(ckpt)

# Build the generator with its default config and load the weights.
my_lct = LCTGenerator(LCTGeneratorConfig()).to(device)
missing, unexpected = my_lct.load_state_dict(sd, strict=False)
# strict=False tolerates key mismatches, so report them explicitly: otherwise
# a checkpoint whose keys do not line up leaves (part of) the generator at its
# random initialisation and the metrics below are meaningless.
if missing or unexpected:
    print(f"[warn] checkpoint does not fully match LCTGenerator: "
          f"{len(missing)} missing / {len(unexpected)} unexpected keys")
    print(f"       missing (first 5):    {list(missing)[:5]}")
    print(f"       unexpected (first 5): {list(unexpected)[:5]}")
my_lct.eval()

# Initialise the DeepFilterNet baseline via df.enhance.init_df(); the third
# return value is discarded.
df_model, df_state, _ = init_df()
# Sample rate reported by the DeepFilterNet state object.
# NOTE(review): df_sr is not referenced again in the visible cells.
df_sr = df_state.sr()

# Bundle the three systems under comparison (scripted FTFNet, the locally
# trained generator, DeepFilterNet) into the project's ModelComparator.
# metrics_sr presumably sets the sample rate used for metric computation —
# confirm against util.ModelComparator.
mc = ModelComparator(
    lct=jit,
    my_lct=my_lct,
    dfn=df_model,
    dfn_state=df_state,
    device=device,
    metrics_sr=16000,
)

[32m2026-01-09 00:18:07[0m | [1mINFO    [0m | [36mDF[0m | [1mRunning on torch 2.8.0+cpu[0m
[32m2026-01-09 00:18:07[0m | [1mINFO    [0m | [36mDF[0m | [1mRunning on host hashbrownmsi[0m
[32m2026-01-09 00:18:07[0m | [1mINFO    [0m | [36mDF[0m | [1mGit commit: 8e66e3e, branch: main[0m
[32m2026-01-09 00:18:07[0m | [1mINFO    [0m | [36mDF[0m | [1mLoading model settings of DeepFilterNet3[0m
[32m2026-01-09 00:18:07[0m | [1mINFO    [0m | [36mDF[0m | [1mUsing DeepFilterNet3 model at C:\Users\12624\AppData\Local\DeepFilterNet\DeepFilterNet\Cache\DeepFilterNet3[0m
[32m2026-01-09 00:18:07[0m | [1mINFO    [0m | [36mDF[0m | [1mInitializing model `deepfilternet3`[0m
[32m2026-01-09 00:18:07[0m | [1mINFO    [0m | [36mDF[0m | [1mFound checkpoint C:\Users\12624\AppData\Local\DeepFilterNet\DeepFilterNet\Cache\DeepFilterNet3\checkpoints\model_120.ckpt.best with epoch 120[0m
[32m2026-01-09 00:18:07[0m | [1mINFO    [0m | [36mDF[0m | [1mRunning on 

In [4]:
import math
import numpy as np
import soundfile as sf


def to_mono(x):
    """Return a 1-D signal: pass 1-D input through, else average over axis 1."""
    return x if x.ndim == 1 else x.mean(axis=1)


def match_length_by_tiling(noise, target_len: int):
    """Fit ``noise`` to exactly ``target_len`` samples.

    * same length  -> returned as-is
    * empty input  -> float32 zeros of length ``target_len``
    * too long     -> truncated to the first ``target_len`` samples
    * too short    -> repeated end-to-end, then truncated
    """
    available = len(noise)
    if available == target_len:
        return noise
    if available <= 0:
        return np.zeros((target_len, ), dtype=np.float32)
    if available >= target_len:
        return noise[:target_len]

    # Smallest whole number of copies that covers the target length.
    copies = int(math.ceil(target_len / available))
    tiled = np.tile(noise, copies)
    return tiled[:target_len]


def gain_db(x, db: float):
    """Scale ``x`` by an amplitude gain given in decibels (20*log10 convention)."""
    return x * (10.0**(db / 20.0))


def rms(x, eps: float = 1e-12) -> float:
    """Root-mean-square level of ``x`` as a Python float.

    ``eps`` is added to the mean square before the square root is taken.
    """
    mean_square = np.mean(x * x)
    return float(np.sqrt(mean_square + eps))


def clip_ratio(x, thr: float = 1.0) -> float:
    """Fraction of samples whose magnitude is at or above ``thr`` (0.0 if empty)."""
    if not x.size:
        return 0.0
    at_or_over = np.abs(x) >= thr
    return float(np.mean(at_or_over))


def hard_clip(x, thr: float = 0.95):
    """Limit every sample of ``x`` to the symmetric range [-thr, +thr]."""
    lower, upper = -thr, thr
    return np.clip(x, lower, upper)


def saturation_by_target_ratio(
    x: np.ndarray,
    clip_thr: float = 0.95,
    target_clip_ratio: float = 0.05,
    eps: float = 1e-12,
):
    """Simulate partial saturation of ``x``.

    The gain is chosen so that the (1 - target_clip_ratio) quantile of |x|
    lands on ``clip_thr``; the boosted signal is hard-clipped at ±clip_thr
    and then divided by the same gain to restore the original level.

    Returns ``(y, applied_gain_lin, applied_gain_db)``. Empty input, or input
    whose quantile magnitude is below ``eps``, is returned untouched with a
    gain of ``1.0`` / ``0.0`` dB.
    """
    passthrough = (x, 1.0, 0.0)

    magnitudes = np.abs(x).reshape(-1)
    if magnitudes.size == 0:
        return passthrough

    quantile_level = float(np.quantile(magnitudes, 1.0 - target_clip_ratio))
    if quantile_level < eps:
        return passthrough

    gain = clip_thr / (quantile_level + eps)
    boosted = x * gain
    restored = hard_clip(boosted, thr=clip_thr) / gain
    return restored, gain, 20.0 * math.log10(gain + eps)


def build_tasks():
    """Return the ordered list of mixing-condition specs for the experiment.

    Every spec is a dict with the keys ``name``, ``clean_gain_db``,
    ``noise_gain_db`` and ``sat`` (always ``None`` for the active tasks).
    """

    def spec(name, clean_db=0.0, noise_db=0.0):
        return dict(name=name,
                    clean_gain_db=clean_db,
                    noise_gain_db=noise_db,
                    sat=None)

    clean_cuts = (-5.0, -10.0)
    noise_boosts = (5.0, 10.0)

    # 1) baseline
    specs = [spec("t1_baseline")]

    # 2) clean -5/-10 dB then add noise
    specs += [spec(f"t2_clean{int(cg)}db", clean_db=cg) for cg in clean_cuts]

    # 3) noise +5/+10 dB then add clean
    specs += [
        spec(f"t3_noise+{int(ng)}db", noise_db=ng) for ng in noise_boosts
    ]

    # 4) noise +5/+10 AND clean -5/-10 (clean level varies slowest)
    specs += [
        spec(f"t4_clean{int(cg)}db_noise+{int(ng)}db",
             clean_db=cg,
             noise_db=ng) for cg in clean_cuts for ng in noise_boosts
    ]

    # 5) clean very quiet
    specs.append(spec("t5_clean-20db", clean_db=-20.0))

    # 6) (disabled) saturation on the final noisy mixture
    #    (5% / 3% / 2% clipped at 0.95):
    # for p in (0.05, 0.03, 0.02):
    #     specs.append(
    #         dict(name=f"t6_sat_thr0p95_clip{int(p*100)}pct",
    #              clean_gain_db=0.0,
    #              noise_gain_db=0.0,
    #              sat=dict(clip_thr=0.95, target_clip_ratio=p)))

    return specs


# Build the task list once for the experiment loop; the tuple on the last
# line is only there to display the count and the first three specs.
TASKS = build_tasks()
len(TASKS), TASKS[:3]


(10,
 [{'name': 't1_baseline',
   'clean_gain_db': 0.0,
   'noise_gain_db': 0.0,
   'sat': None},
  {'name': 't2_clean-5db',
   'clean_gain_db': -5.0,
   'noise_gain_db': 0.0,
   'sat': None},
  {'name': 't2_clean-10db',
   'clean_gain_db': -10.0,
   'noise_gain_db': 0.0,
   'sat': None}])

In [5]:
import pandas as pd
import torchaudio
import soundfile as sf
import torch
from tqdm import tqdm


def resample_sf_to_16k(wav_np, sr, target_sr=16000):
    """Resample a soundfile-style array to ``target_sr``.

    Accepts a mono ``[T]`` array or a ``[T, C]`` array and returns
    ``(resampled_array, target_sr)`` in the same layout. When ``sr`` already
    equals ``target_sr`` the input array is returned untouched together with
    ``sr``.

    Raises:
        ValueError: if the input is neither 1-D nor 2-D.
    """
    if sr == target_sr:
        return wav_np, sr

    tensor = torch.from_numpy(wav_np)
    is_mono = tensor.ndim == 1
    if not is_mono and tensor.ndim != 2:
        raise ValueError(f"Unexpected wav shape: {wav_np.shape}")

    # The resampler is fed [C, T]: add a channel axis for mono input, swap
    # axes for soundfile's [T, C] layout.
    channels_first = tensor.unsqueeze(0) if is_mono else tensor.transpose(0, 1)

    resampled = torchaudio.functional.resample(
        channels_first.to(torch.float32),
        orig_freq=sr,
        new_freq=target_sr,
    )

    # Restore the caller's layout.
    out_np = resampled.cpu().numpy()
    out_np = out_np.squeeze(0) if is_mono else out_np.transpose(1, 0)
    return out_np, target_sr


def flatten_result(meta: dict, result: dict):
    """Turn one comparison result into flat, CSV-friendly rows.

    Args:
        meta: Columns shared by every row of this mixture; copied per row,
            never mutated.
        result: Mapping from model key to a dict of per-model fields. Only
            the keys ``noisy``, ``ftfnet``, ``my_ftfnet`` and ``dfn`` are
            read, in that order; absent keys are skipped.

    Returns:
        A list with one dict per present model: ``meta`` plus a ``model``
        column plus every scalar field (str / int / float / bool / None) of
        that model's entry. Non-scalar fields (arrays, lists, ...) are
        dropped.
    """
    rows = []
    for k in ["noisy", "ftfnet", "my_ftfnet", "dfn"]:
        if k not in result:
            continue
        row = dict(meta)
        row["model"] = k
        # copy all simple fields in result[k]
        for kk, vv in result[k].items():
            # NumPy scalars such as np.float32 / np.int64 are not instances of
            # the Python builtins checked below and would be silently dropped;
            # convert them to the equivalent Python scalar first.
            if isinstance(vv, np.generic):
                vv = vv.item()
            if isinstance(vv, (str, int, float, bool)) or vv is None:
                row[kk] = vv
        rows.append(row)
    return rows


# Experiment driver: for every (clean, noise, task) combination, synthesise
# the noisy mixture, write it to disk, run it through the ModelComparator and
# append the resulting rows to a single metrics CSV.
CSV_PATH = NOISY_DIR / "experiment_metrics.csv"

# Sorted so the processing order is stable between runs.
clean_paths = sorted(CLEAN_DIR.glob("*.wav"))
noise_paths = sorted(NOISE_DIR.glob("*.wav"))

# Fail early with a clear message instead of silently producing no rows.
if not clean_paths:
    raise RuntimeError(f"No clean wav found under {CLEAN_DIR}")
if not noise_paths:
    raise RuntimeError(f"No noise wav found under {NOISE_DIR}")

# Accumulates one row per (mixture, model) across the whole run.
all_rows = []

for clean_path in clean_paths:
    clean_stem = clean_path.stem
    # Read as float32, downmix to mono and bring to 16 kHz.
    clean_wav, sr_c = sf.read(str(clean_path), dtype="float32")
    clean_wav = to_mono(clean_wav).astype(np.float32)

    if sr_c != 16000:
        clean_wav, sr_c = resample_sf_to_16k(clean_wav, sr_c, target_sr=16000)

    # NOTE(review): each noise file is re-read and re-resampled for every
    # clean file; it could be cached if there are many clean files.
    for noise_path in tqdm(noise_paths):
        noise_stem = noise_path.stem
        noise_wav, sr_n = sf.read(str(noise_path), dtype="float32")
        noise_wav = to_mono(noise_wav).astype(np.float32)

        if sr_n != 16000:
            noise_wav, sr_n = resample_sf_to_16k(noise_wav,
                                                 sr_n,
                                                 target_sr=16000)

        # Tile or truncate the noise so it spans exactly the clean signal.
        noise_aligned = match_length_by_tiling(noise_wav, len(clean_wav))

        for spec in TASKS:
            task = spec["name"]

            # Apply the per-task gains, then mix by simple addition.
            c2 = gain_db(clean_wav, float(spec["clean_gain_db"]))
            n2 = gain_db(noise_aligned, float(spec["noise_gain_db"]))
            mix_pre = c2 + n2

            # Optional saturation on the noisy mixture; the defaults below
            # describe the "no saturation" case.
            sat_gain_lin = 1.0
            sat_gain_db = 0.0
            mix_post = mix_pre
            if spec["sat"] is not None:
                mix_post, sat_gain_lin, sat_gain_db = saturation_by_target_ratio(
                    mix_pre,
                    clip_thr=float(spec["sat"]["clip_thr"]),
                    target_clip_ratio=float(spec["sat"]["target_clip_ratio"]),
                )

            # Limit the written signal to [-1, 1]; the clip_ratio_* columns
            # below record how many samples sit at or beyond that limit.
            mix_saved = np.clip(mix_post, -1.0, 1.0).astype(np.float32)

            task_noisy_dir = NOISY_DIR / task
            task_noisy_dir.mkdir(parents=True, exist_ok=True)

            noisy_path = task_noisy_dir / f"{task}_noisy_{clean_stem}_{noise_stem}.wav"
            sf.write(str(noisy_path), mix_saved, 16000)

            # Results go to a directory of their own per task and mixture.
            # NOTE(review): clean_path points at the ORIGINAL clean file, not
            # the gain-scaled (and possibly resampled) c2 — confirm that
            # process_one_file handles level/sample-rate differences when it
            # computes reference-based metrics.
            out_dir = ENH_DIR / task / f"{clean_stem}_{noise_stem}"
            result = mc.process_one_file(
                noisy_path=str(noisy_path),
                out_dir=str(out_dir),
                clean_path=str(clean_path),
            )

            # Metadata row (shared across the model outputs). The sat_*
            # columns are None whenever the task has no saturation stage.
            meta = {
                "task":
                task,
                "clean":
                clean_stem,
                "noise":
                noise_stem,
                "clean_path":
                str(clean_path),
                "noise_path":
                str(noise_path),
                "noisy_path":
                str(noisy_path),
                "clean_gain_db":
                float(spec["clean_gain_db"]),
                "noise_gain_db":
                float(spec["noise_gain_db"]),
                "sat_enabled":
                spec["sat"] is not None,
                "sat_clip_thr":
                float(spec["sat"]["clip_thr"]) if spec["sat"] else None,
                "sat_target_clip_ratio":
                float(spec["sat"]["target_clip_ratio"])
                if spec["sat"] else None,
                "sat_applied_gain_lin":
                float(sat_gain_lin) if spec["sat"] else None,
                "sat_applied_gain_db":
                float(sat_gain_db) if spec["sat"] else None,
                "rms_clean_scaled":
                rms(c2),
                "rms_noise_scaled":
                rms(n2),
                "rms_mix_pre":
                rms(mix_pre),
                "rms_mix_saved":
                rms(mix_saved),
                "peak_mix_pre":
                float(np.max(np.abs(mix_pre))) if mix_pre.size else 0.0,
                "peak_mix_saved":
                float(np.max(np.abs(mix_saved))) if mix_saved.size else 0.0,
                "clip_ratio_pre_at1":
                clip_ratio(mix_pre, thr=1.0),
                "clip_ratio_saved_at1":
                clip_ratio(mix_saved, thr=1.0),
            }

            # SNR of the scaled components as an RMS ratio in dB; the extra
            # 1e-12 terms keep the ratio finite for silent inputs.
            meta["snr_db_scaled"] = 20.0 * math.log10(
                (meta["rms_clean_scaled"] + 1e-12) /
                (meta["rms_noise_scaled"] + 1e-12))

            all_rows.extend(flatten_result(meta, result))

            # The full CSV is rewritten after every mixture, so an
            # interrupted run still leaves the rows gathered so far on disk.
            # NOTE(review): total write cost grows quadratically with the
            # number of mixtures — consider writing once per noise file.
            pd.DataFrame(all_rows).to_csv(CSV_PATH, index=False)


100%|██████████| 8/8 [10:29<00:00, 78.74s/it]


In [None]:
from __future__ import annotations

# Default model keys and metric names retained by keep_category_model_metrics.
MODELS = ("ftfnet", "dfn", "my_ftfnet")
METRICS = ("si_sdr", "pesq", "stoi")


def keep_category_model_metrics(
        data,
        *,
        categories=None,  # e.g. {"impulse", "music"}; None = keep all
        models=MODELS,
        metrics=METRICS,
        drop_models_with_no_metrics: bool = True,  # useful for "roadside"
):
    """Project a ``{category: {model: {metric: value}}}`` mapping onto a subset.

    Args:
        data: Nested mapping of category -> model -> metrics dict. A model
            entry that is ``None`` (or otherwise falsy) counts as empty.
        categories: Categories to keep; ``None`` keeps every category.
        models: Model keys to keep, in output order.
        metrics: Metric names to extract; a missing metric becomes ``None``.
        drop_models_with_no_metrics: Skip a model when every requested
            metric is ``None``.

    Returns:
        A new nested dict holding only the requested categories, models and
        metrics. Categories left without any model are omitted.
    """
    wanted = None if categories is None else set(categories)

    filtered = {}
    for category, per_model in data.items():
        if wanted is not None and category not in wanted:
            continue

        selected = {}
        for model in models:
            if model not in per_model:
                continue

            raw = per_model[model] or {}
            picked = {name: raw.get(name) for name in metrics}

            has_value = any(v is not None for v in picked.values())
            if has_value or not drop_models_with_no_metrics:
                selected[model] = picked

        if selected:
            filtered[category] = selected

    return filtered


import pandas as pd

# NOTE(review): `results` is not defined in any cell visible in this
# notebook — it has to exist in the kernel already for this cell to run.
filtered = keep_category_model_metrics(results)

# One flat record per (category, model) pair, with the metrics as columns.
rows = [
    {"category": cat_name, "model": model_name, **scores}
    for cat_name, per_model in filtered.items()
    for model_name, scores in per_model.items()
]

# Sort by category, then model, and renumber the index from zero.
df = (
    pd.DataFrame(rows)
    .sort_values(["category", "model"])
    .reset_index(drop=True)
)

In [15]:
# Render the `df` frame as this cell's output.
df

Unnamed: 0,category,model,si_sdr,pesq,stoi
0,impulse,dfn,15.979971,2.830872,0.961708
1,impulse,ftfnet,20.484844,3.407346,0.980956
2,impulse,my_ftfnet,-9.13133,1.027396,0.602254
3,music,dfn,20.817392,3.575815,0.85111
4,music,ftfnet,18.288195,3.72775,0.866259
5,music,my_ftfnet,-15.402922,1.091536,0.445325
6,static1,dfn,7.943667,1.645568,0.90314
7,static1,ftfnet,7.524585,1.805562,0.913023
8,static1,my_ftfnet,-12.782477,1.051258,0.496695
9,static2,dfn,15.979971,2.830872,0.961708


In [None]:
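# Test plan:
# 1. Attenuate the audio by 20/30 dB, then run inference
# 2. Quiet speech with loud noise
#    Swap in two kinds of noise: keep the "water" clean clip and test it against several other noises
# 3. Make the speech even quieter and test with different noises
# 4. Amplify the noisy signal a little to introduce some saturation (clipping)
#    at 95% (5%/2%/3% of samples saturated); after clipping, scale back down (partially saturated)
#    Saturation can expose performance differences between strong models
# 5. Wind noise and driving (in-car) noise

# Train for a while, then test again; VoiceBank 8/2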