In [None]:
import torch
import torch.nn.functional as F
from fastai.learner import Metric
import numpy as np
from pesq import pesq
from pystoi import stoi

In [None]:
%run model.ipynb

In [None]:
def pesq_metric(pred, targ, sr=16000, mode='wb'):
    """PESQ metric - computed on CPU.

    Scores every (reference, degraded) pair in the batch with ``pesq`` and
    returns the mean over the utterances that could be scored.

    Args:
        pred: predicted (enhanced) waveforms as a tensor; dim 1 is squeezed
            away when it has size 1 (presumably (batch, 1, time) -> (batch, time)).
        targ: clean reference waveforms, same layout as ``pred``.
        sr: sample rate in Hz passed to ``pesq`` (default 16000).
        mode: PESQ mode passed to ``pesq`` (default ``'wb'``).

    Returns:
        Mean PESQ over the scored utterances, or 0.0 when none could be
        scored. Utterances for which ``pesq`` raises (e.g. silent segments)
        are skipped rather than counted as 0.0, which is not a meaningful
        PESQ score and would bias the mean downwards.
    """
    # .float() first: bfloat16/half tensors from mixed precision cannot
    # always be converted to numpy directly.
    pred = pred.squeeze(1).detach().float().cpu().numpy()
    targ = targ.squeeze(1).detach().float().cpu().numpy()
    scores = []
    for ref, deg in zip(targ, pred):
        try:
            scores.append(pesq(sr, ref, deg, mode))
        except Exception:
            # Best-effort metric: skip utterances PESQ cannot score.
            continue
    return np.mean(scores) if scores else 0.0

In [None]:
def stoi_metric(pred, targ, sr=16000, extended=False):
    """STOI metric - computed on CPU.

    Scores every (clean, processed) pair in the batch with ``stoi`` and
    returns the mean over the utterances that could be scored.

    Args:
        pred: predicted (enhanced) waveforms as a tensor; dim 1 is squeezed
            away when it has size 1 (presumably (batch, 1, time) -> (batch, time)).
        targ: clean reference waveforms, same layout as ``pred``.
        sr: sample rate in Hz passed to ``stoi`` (default 16000).
        extended: forwarded to ``stoi`` (default False, i.e. classic STOI).

    Returns:
        Mean STOI over the scored utterances, or 0.0 when none could be
        scored. Utterances for which ``stoi`` raises are skipped rather than
        counted as 0.0, which would bias the mean downwards.
    """
    # .float() first: bfloat16/half tensors from mixed precision cannot
    # always be converted to numpy directly.
    pred = pred.squeeze(1).detach().float().cpu().numpy()
    targ = targ.squeeze(1).detach().float().cpu().numpy()
    scores = []
    for clean, processed in zip(targ, pred):
        try:
            scores.append(stoi(clean, processed, sr, extended=extended))
        except Exception:
            # Best-effort metric: skip utterances STOI cannot score.
            continue
    return np.mean(scores) if scores else 0.0

In [None]:
def evaluate_checkpoint(model_path, dls, channels=48, num_blocks=10, 
                        num_repeats=2, n_batches=20, sr=16000):
    """
    Compute PESQ and STOI on a saved model.
    Call this once after training completes, not during training.

    The model is rebuilt as ``CausalDNoizeConvTasNet`` and run on the CPU.
    PESQ (wide-band) and STOI are computed per utterance and averaged.

    Args:
        model_path: checkpoint file. It may be a dict with the weights under
            ``'model'`` (e.g. saved by SaveModelCallback), a bare state_dict,
            or a pickled ``nn.Module``.
        dls: object whose ``.valid`` yields ``(noisy, clean)`` batches.
        channels, num_blocks, num_repeats: architecture hyper-parameters;
            they must match the ones the checkpoint was trained with.
        n_batches: number of validation batches to evaluate (set to None
            for all; 0 evaluates none).
        sr: sample rate in Hz passed to PESQ and STOI (default 16000).

    Returns:
        ``(pesq_mean, stoi_mean)``. Each is 0.0 when no utterance could be
        scored. Utterances a metric fails on are skipped and reported.
    """
    import warnings

    model = CausalDNoizeConvTasNet(
        channels=channels, 
        num_blocks=num_blocks,
        num_repeats=num_repeats 
    )

    # weights_only=False un-pickles arbitrary objects, which can execute
    # code: only load checkpoints from a trusted source.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        state = torch.load(model_path, map_location='cpu', weights_only=False)

    # SaveModelCallback saves full learner state, not just model state_dict.
    # Handle that, a bare state_dict, and a whole pickled module.
    if isinstance(state, torch.nn.Module):
        state = state.state_dict()
    elif isinstance(state, dict) and 'model' in state:
        state = state['model']
    model.load_state_dict(state)

    model.eval()

    pesq_scores, stoi_scores = [], []
    pesq_failed = stoi_failed = 0

    for batch_idx, (xb, yb) in enumerate(dls.valid):
        # Compare against None explicitly: n_batches=0 means "no batches",
        # whereas a plain truthiness test would evaluate everything.
        if n_batches is not None and batch_idx >= n_batches:
            break

        xb_cpu = xb.float().cpu()
        yb_cpu = yb.float().cpu()

        with torch.no_grad():
            pred = model(xb_cpu)

        pred_np = pred.squeeze(1).float().numpy()
        targ_np = yb_cpu.squeeze(1).numpy()

        for deg, ref in zip(pred_np, targ_np):
            # `except Exception` (not bare `except`) so Ctrl-C still works.
            try:
                pesq_scores.append(pesq(sr, ref, deg, 'wb'))
            except Exception:
                pesq_failed += 1
            try:
                stoi_scores.append(stoi(ref, deg, sr, extended=False))
            except Exception:
                stoi_failed += 1

    if pesq_failed or stoi_failed:
        print(f"Skipped {pesq_failed} PESQ / {stoi_failed} STOI utterance(s) that could not be scored")

    pesq_mean = np.mean(pesq_scores) if pesq_scores else 0.0
    stoi_mean = np.mean(stoi_scores) if stoi_scores else 0.0
    print(f"PESQ : {pesq_mean:.4f} (target > 2.5)")
    print(f"STOI : {stoi_mean:.4f} (target > 0.88)")
    return pesq_mean, stoi_mean

In [None]:
class SISNRMetric(Metric):
    """
    Positive SI-SNR in dB (higher is better), averaged over utterances.

    The per-utterance SI-SNR is computed on whatever device the predictions
    live on; only the per-batch sum (a scalar) is copied to the CPU, so no
    full-size tensors are transferred each batch.

    Indicative values (author's expectations, not checked by this code):
      epoch  0: ~7-9 dB
      epoch 20: ~11-13 dB
      epoch 50: ~14-16 dB
      epoch 80: ~17-20 dB
    """
    def reset(self):
        # float64 accumulator limits round-off when summing many batches.
        self.total = torch.tensor(0.0, dtype=torch.float64)
        self.count = 0

    def accumulate(self, learn):
        # squeeze(1) drops a size-1 dim 1 (presumably the channel axis);
        # .float() upcasts half-precision outputs before the reductions.
        pred = learn.pred.squeeze(1).detach().float()
        targ = learn.yb[0].squeeze(1).detach().float()

        # SI-SNR is defined on zero-mean signals.
        targ_zm = targ - targ.mean(dim=-1, keepdim=True)
        pred_zm = pred - pred.mean(dim=-1, keepdim=True)

        eps = 1e-8
        # Optimal scaling of the target onto the prediction.
        alpha = (targ_zm * pred_zm).sum(-1, keepdim=True) / \
                (targ_zm.pow(2).sum(-1, keepdim=True) + eps)
        s = (alpha * targ_zm).pow(2).sum(-1)
        n = (pred_zm - alpha * targ_zm).pow(2).sum(-1)
        sisnr = 10 * torch.log10((s + eps) / (n + eps))

        # Only this scalar leaves the device.
        self.total += sisnr.sum().double().cpu()
        self.count += pred.shape[0]

    @property
    def value(self):
        """Mean SI-SNR in dB over all accumulated utterances, or None if none."""
        return (self.total / self.count).item() if self.count > 0 else None

    @property
    def name(self):
        return "sisnr_db"

In [None]:
class NoiseReductionPct(Metric):
    """
    Percentage of noise power removed, averaged over utterances.

    Per utterance: 100 * (P_in - P_out) / P_in, clamped to [-100, 100].
    P_in is the mean power of (noisy - clean). P_out is the mean power of
    (pred - clean).
    Negative = model is making audio worse than noisy input.
    Expected to be negative early in training, crossing 0 around epoch 15-25.

    Computed on the device the batch lives on; only the per-batch sum (a
    scalar) is copied to the CPU.
    """
    def reset(self):
        # float64 accumulator limits round-off when summing many batches.
        self.total = torch.tensor(0.0, dtype=torch.float64)
        self.count = 0

    def accumulate(self, learn):
        # squeeze(1) drops a size-1 dim 1 (presumably the channel axis);
        # .float() upcasts half-precision tensors before squaring.
        noisy = learn.xb[0].squeeze(1).detach().float()
        clean = learn.yb[0].squeeze(1).detach().float()
        pred  = learn.pred.squeeze(1).detach().float()

        original_noise_pwr  = (noisy - clean).pow(2).mean(-1)
        remaining_noise_pwr = (pred  - clean).pow(2).mean(-1)

        # clamp(-100, 100) rather than (0, 100) so negative values (audio made
        # worse) stay visible; the 1e-8 guards against a noise-free input.
        pct = ((original_noise_pwr - remaining_noise_pwr) /
               (original_noise_pwr + 1e-8) * 100).clamp(-100, 100)

        # Only this scalar leaves the device.
        self.total += pct.sum().double().cpu()
        self.count += pred.shape[0]

    @property
    def value(self):
        """Mean percentage over all accumulated utterances, or None if none."""
        return (self.total / self.count).item() if self.count > 0 else None

    @property
    def name(self):
        return "noise_removed_%"