In [None]:
import csv
import hashlib
import os
import time
from datetime import timedelta
from pathlib import Path

import numpy as np
import pyroomacoustics as pra
import torch
import torchaudio
from fastai.callback.all import *
from fastai.optimizer import Lookahead, Adam

In [None]:
def lookahead_adamw(params, **kwargs):
    """
    Optimizer factory: Adam wrapped in Lookahead.

    `decouple_wd` is always forced to True, overriding any value supplied by
    the caller; every other keyword argument is forwarded to `Adam` as-is.
    The Lookahead wrapper is fixed at k=6, alpha=0.5.
    """
    adam_kwargs = {**kwargs, 'decouple_wd': True}
    inner_opt = Adam(params, **adam_kwargs)
    return Lookahead(inner_opt, k=6, alpha=0.5)

In [None]:
class EpochTracker(Callback):
    """
    Tracks epoch count, lr_max and per-epoch metrics.
    Saves to disk after every epoch for clean resume after failure.

    Files:
      tracker_fname — one line "epochs_done,lr_max" (read back on construction)
      log_fname     — per-epoch metrics CSV, one row appended per epoch

    Args:
      lr_max:        fallback lr_max; a value stored in the tracker file wins.
      total_epochs:  stored on the instance; not otherwise used by this class.
      tracker_fname: path of the resume file.
      log_fname:     path of the metrics CSV (header written if file is missing).

    Order=60 ensures this runs after Recorder(order=50) so
    recorder.values[-1] is fully populated when after_epoch fires.
    """
    order = 60

    # Column order of the metrics CSV; rows written by after_epoch match it.
    _LOG_HEADER = (
        'epoch',
        'train_loss',
        'valid_loss',
        'sisnr_db',
        'noise_removed_percentage',
        'current_lr',
        'lr_max',
        'epoch_elapsed_time',
    )

    def __init__(self, lr_max=0, total_epochs=200,
                 tracker_fname='training_stats/epoch_tracker.txt',
                 log_fname='training_stats/training_log.csv'):
        self.saved_lr_max  = lr_max
        self.total_epochs  = total_epochs
        self.tracker_fname = tracker_fname
        self.log_fname     = log_fname
        self._epoch_start  = None
        # A previous run's tracker file, if present, overrides the arguments.
        self.epochs_done, self.saved_lr_max = self._load_tracker()
        self._init_log()

    @staticmethod
    def _ensure_parent_dir(fname):
        """Create the directory holding `fname`; no-op for a bare filename."""
        parent = os.path.dirname(fname)
        if parent:  # os.makedirs('') raises, so skip when there is no dir part
            os.makedirs(parent, exist_ok=True)

    @staticmethod
    def _fmt_metric(value):
        """Format a metric with 6 decimals, or '' when it is missing."""
        return f"{value:.6f}" if value is not None else ''

    def _load_tracker(self):
        """
        Read (epochs_done, lr_max) from the tracker file.

        Returns (0, self.saved_lr_max) when the file is missing, unreadable or
        malformed. A missing file is the normal fresh-start case and is silent;
        any other read/parse failure is reported on stdout so an unintended
        restart from epoch 0 does not go unnoticed.
        """
        try:
            with open(self.tracker_fname, 'r') as f:
                parts = f.read().strip().split(',')
            epochs = int(parts[0])
            lr_max = float(parts[1]) if len(parts) > 1 else self.saved_lr_max
            return epochs, lr_max
        except FileNotFoundError:
            return 0, self.saved_lr_max
        except (OSError, ValueError) as err:
            print(f"\n[EpochTracker] Could not read {self.tracker_fname!r} "
                  f"({err}) — starting from epoch 0")
            return 0, self.saved_lr_max

    def _save_tracker(self):
        """
        Persist "epochs_done,lr_max".

        The content is written to a sibling temp file and then moved over the
        real one with os.replace, so an interruption during the write cannot
        leave a truncated tracker file behind.
        """
        self._ensure_parent_dir(self.tracker_fname)
        tmp_fname = f"{self.tracker_fname}.tmp"
        with open(tmp_fname, 'w') as f:
            f.write(f"{self.epochs_done},{self.saved_lr_max}")
        os.replace(tmp_fname, self.tracker_fname)

    def _init_log(self):
        """Create the metrics CSV with its header row if it does not exist yet."""
        self._ensure_parent_dir(self.log_fname)
        if not Path(self.log_fname).exists():
            with open(self.log_fname, 'w', newline='') as f:
                csv.writer(f).writerow(self._LOG_HEADER)

    def before_epoch(self):
        """Start the epoch timer and publish the epoch number via set_epoch_seed."""
        self._epoch_start = time.time()
        set_epoch_seed(self.epochs_done)

    def after_epoch(self):
        """
        Runs after Recorder.after_epoch (order=50) so recorder.values[-1]
        is fully populated with [train_loss, valid_loss, sisnr_db, noise_removed_%]

        Appends one CSV row for the finished epoch, then saves the tracker file.
        """
        started = self._epoch_start
        elapsed = time.time() - started if started is not None else 0.0
        self.epochs_done += 1

        recorder = self.learn.recorder

        # recorder.values[-1] layout (confirmed from debug output):
        # [train_loss, valid_loss, sisnr_db, noise_removed_%]
        # Missing trailing entries are padded with None and logged as ''.
        vals = list(recorder.values[-1]) if recorder.values else []
        metrics = (vals + [None] * 4)[:4]

        current_lr = self.learn.opt.hypers[-1]['lr'] if self.learn.opt else None

        row = [self.epochs_done]
        row.extend(self._fmt_metric(m) for m in metrics)
        row.extend([
            f"{current_lr}" if current_lr is not None else '',
            f"{self.saved_lr_max}",
            f"{timedelta(seconds=elapsed)}",
        ])

        with open(self.log_fname, 'a', newline='') as f:
            csv.writer(f).writerow(row)

        self._save_tracker()

In [None]:
class SISNRDiagnostic(Callback):
    """
    Computes manual SI-SNR on validation batch after each epoch
    and writes to file for monitoring without interrupting training.

    Args:
      fname:   CSV file the diagnostics are appended to (header written if
               the file does not exist yet).
      every_n: run only when `self.epoch % every_n == 0`.

    order=65 places this after Recorder(order=50), so len(recorder.values)
    already includes the epoch that just finished and the logged epoch number
    lines up with the other callbacks that derive it the same way.
    """
    order = 65

    def __init__(self, fname='training_stats/sisnr_diagnostic.csv', every_n=1):
        self.fname   = fname
        self.every_n = every_n
        self._init_file()

    def _init_file(self):
        """Create the output directory and the CSV header if needed."""
        parent = os.path.dirname(self.fname)
        if parent:  # os.makedirs('') raises, so skip for a bare filename
            os.makedirs(parent, exist_ok=True)
        if not Path(self.fname).exists():
            with open(self.fname, 'w', newline='') as f:
                csv.writer(f).writerow([
                    'epoch',
                    'sisnr_sample_0', 'sisnr_sample_1',
                    'sisnr_sample_2', 'sisnr_sample_3',
                    'sisnr_mean',
                    'pred_min', 'pred_max',
                    'pred_power_mean', 'targ_power_mean',
                    'power_ratio'
                ])

    def after_epoch(self):
        """
        Evaluate one validation batch and append a diagnostics row.

        Any exception raised while computing or writing the row is caught and
        recorded as an "ERROR: ..." row instead. The model is switched back to
        train mode in every case.
        """
        if self.epoch % self.every_n != 0:
            return

        # Resolved before the try block so the normal row and the error row
        # always carry the same epoch number.
        # NOTE(review): this count follows recorder.values; confirm it is not
        # reset between fit() calls if resume-safe numbering is required.
        epoch_num = len(self.learn.recorder.values)

        try:
            self.learn.model.eval()
            xb, yb = self.learn.dls.valid.one_batch()

            with torch.no_grad():
                pred = self.learn.model(xb)

            # assumes pred / yb are [batch, 1, time] — TODO confirm
            pred_sq = pred.squeeze(1).detach().float().cpu()
            targ_sq = yb.squeeze(1).detach().float().cpu()

            # SI-SNR on zero-mean signals: project the prediction onto the
            # target, then compare projected energy with residual energy.
            targ_zm = targ_sq - targ_sq.mean(dim=-1, keepdim=True)
            pred_zm = pred_sq - pred_sq.mean(dim=-1, keepdim=True)

            eps = 1e-8
            alpha = (targ_zm * pred_zm).sum(-1, keepdim=True) / \
                    (targ_zm.pow(2).sum(-1, keepdim=True) + eps)
            s = (alpha * targ_zm).pow(2).sum(-1)
            n = (pred_zm - alpha * targ_zm).pow(2).sum(-1)
            sisnr = 10 * torch.log10((s + eps) / (n + eps))

            # Raw RMS power ratio — consistent with power_reg in loss
            # raw signal instead of zero-mean so it matches NoiseReductionPct
            pred_rms = pred_sq.pow(2).mean(dim=-1).sqrt()  # per sample
            targ_rms = targ_sq.pow(2).mean(dim=-1).sqrt()  # per sample
            mean_ratio = (pred_rms / (targ_rms + eps)).mean().item()

            # Exactly four per-sample columns: first four samples of the
            # batch, padded with '' when the batch is smaller.
            sample_cells = [f"{v:.4f}" for v in sisnr.tolist()[:4]]
            sample_cells += [''] * (4 - len(sample_cells))

            with open(self.fname, 'a', newline='') as f:
                csv.writer(f).writerow([
                    epoch_num,
                    *sample_cells,
                    f"{sisnr.mean().item():.4f}",
                    f"{pred_sq.min().item():.4f}",
                    f"{pred_sq.max().item():.4f}",
                    f"{pred_rms.mean().item():.4f}",
                    f"{targ_rms.mean().item():.4f}",
                    f"{mean_ratio:.4f}"
                ])

        except Exception as e:
            # Same width as the header: epoch, message, then 9 empty cells.
            with open(self.fname, 'a', newline='') as f:
                csv.writer(f).writerow([epoch_num, f"ERROR: {e}"] + [''] * 9)
        finally:
            self.learn.model.train()

In [None]:
class PeriodicPESQSTOI(Callback):
    """
    Computes PESQ and STOI on CPU every N epochs on a small validation subset.
    Errors raised during evaluation are caught and logged rather than
    propagated; samples whose PESQ/STOI computation fails are skipped and
    counted, and the counts are printed with the summary.

    PESQ target: > 2.5
    STOI target: > 0.88
    """
    order = 70  # after EpochTracker(60) and Recorder(50)

    def __init__(self, fname='training_stats/pesq_stoi_log.csv',
                 every_n=10, n_batches=10):
        """
        fname:     CSV file results are appended to
        every_n:   compute every N epochs
        n_batches: number of validation batches to evaluate
                   10 batches x bs=4 = 40 samples, ~2-3 min on CPU
        """
        self.fname = fname
        self.every_n = every_n
        self.n_batches = n_batches
        self._init_file()

    def _init_file(self):
        """Create the output directory and the CSV header if needed."""
        parent = os.path.dirname(self.fname)
        if parent:  # os.makedirs('') raises, so skip for a bare filename
            os.makedirs(parent, exist_ok=True)
        if not Path(self.fname).exists():
            with open(self.fname, 'w', newline='') as f:
                csv.writer(f).writerow([
                    'epoch',
                    'pesq_mean', 'pesq_min', 'pesq_max',
                    'stoi_mean', 'stoi_min', 'stoi_max',
                    'n_samples', 'elapsed_seconds'
                ])

    @staticmethod
    def _summarize(scores):
        """Return (mean, min, max) of `scores` as floats; zeros when empty."""
        if not scores:
            return 0.0, 0.0, 0.0
        return float(np.mean(scores)), float(np.min(scores)), float(np.max(scores))

    def after_epoch(self):
        """
        Every `every_n` epochs, score up to `n_batches` validation batches with
        PESQ (wide-band, 16 kHz) and STOI, print a summary and append a CSV row.
        """
        # self.epoch is 0-based, so epoch 10 = self.epoch 9
        # Use recorder length to get the 1-based epoch number.
        # NOTE(review): confirm recorder.values is not reset between fit()
        # calls if this number is expected to survive a resume.
        epoch_num = len(self.learn.recorder.values)

        if epoch_num % self.every_n != 0:
            return

        t_start = time.time()

        # Optional dependencies: imported lazily so a missing package only
        # skips this evaluation.
        try:
            from pesq import pesq as pesq_fn
            from pystoi import stoi as stoi_fn
        except ImportError as e:
            print(f"\n[PeriodicPESQSTOI] Import error: {e} — skipping")
            return

        try:
            self.learn.model.eval()

            pesq_scores, stoi_scores = [], []
            pesq_failed = stoi_failed = 0
            last_error = None

            for batch_idx, (xb, yb) in enumerate(self.learn.dls.valid):
                if batch_idx >= self.n_batches:
                    break

                with torch.no_grad():
                    pred = self.learn.model(xb)

                # Move to CPU float — avoid DirectML precision issues
                pred_np = pred.squeeze(1).detach().float().cpu().numpy()
                targ_np = yb.squeeze(1).detach().float().cpu().numpy()

                for ref, deg in zip(targ_np, pred_np):
                    # A failing sample is skipped for that metric only, but
                    # counted so the summary shows how many were dropped.
                    try:
                        pesq_scores.append(pesq_fn(16000, ref, deg, 'wb'))
                    except Exception as err:
                        pesq_failed += 1
                        last_error = err

                    try:
                        stoi_scores.append(stoi_fn(ref, deg,
                                                   16000, extended=False))
                    except Exception as err:
                        stoi_failed += 1
                        last_error = err

            elapsed = time.time() - t_start

            pesq_mean, pesq_min, pesq_max = self._summarize(pesq_scores)
            stoi_mean, stoi_min, stoi_max = self._summarize(stoi_scores)
            # Kept as the PESQ sample count so existing CSV consumers see the
            # same column meaning as before.
            n_samples = len(pesq_scores)

            # Print summary
            print(f"\n[Epoch {epoch_num}] PESQ={pesq_mean:.4f} "
                  f"(min={pesq_min:.4f} max={pesq_max:.4f}) | "
                  f"STOI={stoi_mean:.4f} "
                  f"(min={stoi_min:.4f} max={stoi_max:.4f}) | "
                  f"n={n_samples} | {elapsed:.1f}s")

            # Log targets
            pesq_target = "✓" if pesq_mean > 2.5  else "✗"
            stoi_target = "✓" if stoi_mean > 0.88 else "✗"
            print(f"           PESQ target >2.5:  {pesq_target} | "
                  f"STOI target >0.88: {stoi_target}")

            if pesq_failed or stoi_failed:
                print(f"           skipped samples — PESQ: {pesq_failed}, "
                      f"STOI: {stoi_failed} (last error: {last_error})")

            with open(self.fname, 'a', newline='') as f:
                csv.writer(f).writerow([
                    epoch_num,
                    f"{pesq_mean:.4f}", f"{pesq_min:.4f}", f"{pesq_max:.4f}",
                    f"{stoi_mean:.4f}", f"{stoi_min:.4f}", f"{stoi_max:.4f}",
                    n_samples,
                    f"{elapsed:.1f}"
                ])

        except Exception as e:
            elapsed = time.time() - t_start
            print(f"\n[PeriodicPESQSTOI] Error at epoch {epoch_num}: {e}")
            with open(self.fname, 'a', newline='') as f:
                csv.writer(f).writerow([
                    epoch_num,
                    f"ERROR: {e}",
                    '', '', '', '', '', '', f"{elapsed:.1f}"
                ])

        finally:
            self.learn.model.train()

In [None]:
def set_epoch_seed(epoch):
    """
    Publish `epoch` as the current epoch seed.

    Stores the value on the module-level EPOCH_SEED object's `.value`
    attribute. Intended to be called from the main process before each epoch.
    NOTE(review): EPOCH_SEED is defined outside this cell — presumably a
    shared value that worker processes read; confirm.
    """
    EPOCH_SEED.value = epoch

In [None]:
def _load_raw_np(path, target_sr=16000):
    """Load audio file → mono float32 numpy [T]"""
    wave, sr = torchaudio.load(str(path))
    if wave.shape[0] > 1:
        wave = wave.mean(0, keepdim=True)
    if sr != target_sr:
        wave = torchaudio.transforms.Resample(sr, target_sr)(wave)
    return wave.squeeze(0).numpy().astype(np.float32)

In [None]:
def _get_strategy(path, epoch_seed):
    """
    Deterministic strategy selection per file per epoch.
    Same path + epoch → same strategy always.
    Changes every epoch for augmentation diversity.
    """
    h = abs(hash((str(path), int(epoch_seed)))) % 1000
    if h < 500:
        return 'denoise_only'   # 50%
    elif h < 800:
        return 'joint'          # 30%
    else:
        return 'deverb_only'    # 20%

In [None]:
def _get_reverb_rng(noisy_path, epoch_seed):
    """
    Deterministic RNG for reverb simulation.
    Same path + epoch → same room dimensions → same RIR for get_x and get_y.
    """
    seed = abs(hash((str(noisy_path), int(epoch_seed), 'rir'))) % (2**31)
    return np.random.RandomState(seed)

In [None]:
def add_reverb(wave_np, sr=16000, rt60_range=(0.2, 0.6), rng=None):
    """
    Simulate a random shoebox room and return the reverberant version of
    `wave_np`, trimmed or zero-padded to the input length and rescaled to the
    input's RMS level.

    wave_np:    float32 numpy [T] dry signal
    sr:         sample rate handed to the room simulation
    rt60_range: bounds passed to rng.uniform for the RT60 draw
    rng:        numpy RandomState for deterministic output, None for random

    Nothing is caught here: any error from the room simulation propagates to
    the caller, which is expected to handle the fallback.
    """
    rng = np.random.RandomState() if rng is None else rng

    def _rms(signal):
        # Small epsilon keeps the later division safe for silent signals.
        return np.sqrt(np.mean(signal ** 2)) + 1e-8

    target_rms = _rms(wave_np)

    # Draw order matters for reproducibility: RT60, then room size (x, y, z),
    # then source position, then microphone position.
    rt60 = rng.uniform(*rt60_range)
    room_dim = [rng.uniform(lo, hi)
                for lo, hi in ((3.0, 8.0), (3.0, 6.0), (2.5, 3.5))]

    e_absorption, max_order = pra.inverse_sabine(rt60, room_dim)
    if max_order > 12:
        max_order = 12  # cap the reflection order

    room = pra.ShoeBox(
        room_dim, fs=sr,
        materials=pra.Material(e_absorption),
        max_order=max_order
    )

    # Positions are kept at least 0.5 away from every wall.
    source_pos = [rng.uniform(0.5, side - 0.5) for side in room_dim]
    room.add_source(source_pos, signal=wave_np)

    mic_pos = [rng.uniform(0.5, side - 0.5) for side in room_dim]
    room.add_microphone(np.array(mic_pos).reshape(3, 1))

    room.simulate()

    n_samples = len(wave_np)
    wet = room.mic_array.signals[0].astype(np.float32)[:n_samples]
    shortfall = n_samples - len(wet)
    if shortfall > 0:
        wet = np.pad(wet, (0, shortfall))

    return wet * (target_rms / _rms(wet))