In [None]:
import torch
import torch.nn.functional as F
from fastai.learner import Metric
import numpy as np
from pesq import pesq
from pystoi import stoi

In [None]:
def pesq_metric(pred, targ):
    """PESQ metric - computed on CPU.

    Mean wide-band PESQ (sampling rate 16000, mode 'wb') over a batch.
    ``targ`` is the clean reference and ``pred`` the processed signal; dim 1
    is squeezed away before scoring and tensors are detached and copied to
    the CPU as numpy arrays.

    Best-effort: any item whose scoring raises contributes 0.0 to the mean.
    NOTE(review): the 0.0 substitution pulls the mean down, whereas
    evaluate_checkpoint skips failed items instead.

    Returns 0.0 for an empty batch.
    """
    processed = pred.squeeze(1).detach().cpu().numpy()
    reference = targ.squeeze(1).detach().cpu().numpy()

    def _score_or_zero(idx):
        # Indexing stays inside the try so any per-item error maps to 0.0.
        try:
            return pesq(16000, reference[idx], processed[idx], 'wb')
        except Exception:
            return 0.0

    per_item = [_score_or_zero(idx) for idx in range(len(processed))]
    if not per_item:
        return 0.0
    return np.mean(per_item)

In [None]:
def stoi_metric(pred, targ):
    """STOI metric - computed on CPU.

    Mean classic (non-extended) STOI at a 16000 Hz sampling rate over a
    batch. ``targ`` is the clean reference and ``pred`` the processed
    signal; dim 1 is squeezed away before scoring and tensors are detached
    and copied to the CPU as numpy arrays.

    Best-effort: any item whose scoring raises contributes 0.0 to the mean.

    Returns 0.0 for an empty batch.
    """
    processed = pred.squeeze(1).detach().cpu().numpy()
    reference = targ.squeeze(1).detach().cpu().numpy()

    def _score_or_zero(idx):
        # Indexing stays inside the try so any per-item error maps to 0.0.
        try:
            return stoi(reference[idx], processed[idx], 16000, extended=False)
        except Exception:
            return 0.0

    per_item = [_score_or_zero(idx) for idx in range(len(processed))]
    return np.mean(per_item) if len(per_item) > 0 else 0.0

In [None]:
# Execute model.ipynb in this kernel so its definitions are available below;
# evaluate_checkpoint needs CausalDNoizeConvTasNet, which is presumably
# defined there (confirm).
%run model.ipynb

In [None]:
def evaluate_checkpoint(model_path, dls, channels=128, num_blocks=4, n_batches=20):
    """
    Compute mean PESQ and STOI for a saved model on the validation set.

    Call this once after training completes, not during training: both
    metrics are scored item by item on the CPU.

    Parameters
    ----------
    model_path : path to a file holding a ``state_dict`` for
        ``CausalDNoizeConvTasNet`` (loaded onto the CPU with ``torch.load``).
    dls : object exposing ``.valid``, an iterable of ``(xb, yb)`` batches
        (noisy input, clean target), e.g. fastai ``DataLoaders``.
    channels, num_blocks : architecture hyper-parameters; must match the
        checkpoint.
    n_batches : number of validation batches to evaluate
        (set to None, or 0, for all).

    Returns
    -------
    (mean_pesq, mean_stoi). Items whose scoring raises are skipped and
    counted. A mean is NaN if no item could be scored.
    """
    model = CausalDNoizeConvTasNet(channels=channels, num_blocks=num_blocks)
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.eval()

    pesq_scores, stoi_scores = [], []
    pesq_failures = 0
    stoi_failures = 0

    for batch_idx, (xb, yb) in enumerate(dls.valid):
        if n_batches and batch_idx >= n_batches:
            break

        xb_cpu = xb.cpu()
        yb_cpu = yb.cpu()

        with torch.no_grad():
            pred = model(xb_cpu)

        pred_np = pred.squeeze(1).numpy()
        targ_np = yb_cpu.squeeze(1).numpy()

        for i in range(len(pred_np)):
            # `except Exception` (not a bare `except:`) so Ctrl-C /
            # KeyboardInterrupt can still stop a long evaluation.
            try:
                pesq_scores.append(pesq(16000, targ_np[i], pred_np[i], 'wb'))
            except Exception:
                pesq_failures += 1
            try:
                stoi_scores.append(stoi(targ_np[i], pred_np[i], 16000, extended=False))
            except Exception:
                stoi_failures += 1

    # np.mean([]) warns and returns nan; make the "nothing scored" case explicit.
    mean_pesq = np.mean(pesq_scores) if pesq_scores else float('nan')
    mean_stoi = np.mean(stoi_scores) if stoi_scores else float('nan')

    print(f"PESQ : {mean_pesq:.4f} (target > 2.5)")
    print(f"STOI : {mean_stoi:.4f} (target > 0.88)")
    if pesq_failures or stoi_failures:
        print(f"Skipped items - PESQ: {pesq_failures}, STOI: {stoi_failures}")
    return mean_pesq, mean_stoi

In [None]:
# class SignalQualityScore(Metric):
#     """
#     Converts SI-SNR to a percentage score (0-100%).
    
#     Intuition: 
#     - 0 dB SI-SNR = 50% quality
#     - 20 dB SI-SNR = 100% quality
#     - Negative SI-SNR = <50% quality
    
#     This feels like "accuracy" - higher is better, 100% is perfect.
#     """
#     def __init__(self, max_sisnr=20.0):
#         self.max_sisnr = max_sisnr  # Max expected SI-SNR
#         self.reset()
    
#     def reset(self):
#         self.total = 0.0
#         self.count = 0
    
#     def accumulate(self, learn):
#         pred = learn.pred.squeeze(1)
#         targ = learn.yb[0].squeeze(1)
        
#         # Calculate SI-SNR
#         target = targ - targ.mean(dim=-1, keepdim=True)
#         estimate = pred - pred.mean(dim=-1, keepdim=True)
        
#         eps = 1e-8
#         alpha = (target * estimate).sum(dim=-1, keepdim=True) / \
#                 (target.pow(2).sum(dim=-1, keepdim=True) + eps)
#         target_scaled = alpha * target
#         noise = estimate - target_scaled
        
#         s = target_scaled.pow(2).sum(dim=-1)
#         n = noise.pow(2).sum(dim=-1)
#         si_snr = 10 * torch.log10((s + eps) / (n + eps))
        
#         # Convert to percentage (0-100%)
#         # Map: [-inf, 0, max_sisnr, +inf] -> [0%, 50%, 100%, 100%]
#         quality = (si_snr / self.max_sisnr) * 50 + 50
#         quality = torch.clamp(quality, 0, 100)
        
#         bs = pred.shape[0]
#         self.total += quality.sum().detach().cpu()
#         self.count += bs
    
#     @property
#     def value(self):
#         return (self.total / self.count).item() if self.count > 0 else None
    
#     @property
#     def name(self):
#         return "quality_%"


In [None]:
# class SNRImprovement(Metric):
#     """
#     Measures how much SNR improved from input to output (as percentage).
    
#     Example:
#     - Input SNR: 5 dB
#     - Output SNR: 15 dB
#     - Improvement: 15 - 5 = 10 dB -> 10 / 10 * 100 = 100% improvement
    
#     This shows the "gain" from denoising, like accuracy improvement.
#     """
#     def reset(self):
#         self.total = 0.0
#         self.count = 0
    
#     def accumulate(self, learn):
#         noisy = learn.xb[0].squeeze(1)  # Input (noisy)
#         clean = learn.yb[0].squeeze(1)  # Target (clean)
#         pred = learn.pred.squeeze(1)     # Prediction (denoised)
        
#         # Calculate input SNR (noisy vs clean)
#         noise_input = noisy - clean
#         signal_power = (clean ** 2).mean(dim=-1)
#         noise_power_input = (noise_input ** 2).mean(dim=-1)
#         snr_input = 10 * torch.log10(signal_power / (noise_power_input + 1e-8) + 1e-8)
        
#         # Calculate output SNR (pred vs clean)
#         noise_output = pred - clean
#         noise_power_output = (noise_output ** 2).mean(dim=-1)
#         snr_output = 10 * torch.log10(signal_power / (noise_power_output + 1e-8) + 1e-8)
        
#         # Improvement in dB
#         improvement_db = snr_output - snr_input
        
#         # Convert to percentage (0 dB improvement = 0%, 10 dB = 100%)
#         improvement_pct = torch.clamp(improvement_db / 10.0 * 100, 0, 200)
        
#         bs = pred.shape[0]
#         self.total += improvement_pct.sum().detach().cpu()
#         self.count += bs
    
#     @property
#     def value(self):
#         return (self.total / self.count).item() if self.count > 0 else None
    
#     @property
#     def name(self):
#         return "snr_improv_%"

In [None]:
# class WaveformSimilarity(Metric):
#     """
#     Measures how similar the predicted waveform is to the target (0-100%).
    
#     Uses normalized correlation coefficient - feels like accuracy:
#     - 100% = perfect match
#     - 50% = random
#     - 0% = completely wrong
#     """
#     def reset(self):
#         self.total = 0.0
#         self.count = 0
    
#     def accumulate(self, learn):
#         pred = learn.pred.squeeze(1)
#         targ = learn.yb[0].squeeze(1)
        
#         # Normalize both signals
#         pred_norm = pred - pred.mean(dim=-1, keepdim=True)
#         targ_norm = targ - targ.mean(dim=-1, keepdim=True)
        
#         # Correlation coefficient
#         correlation = (pred_norm * targ_norm).sum(dim=-1)
#         pred_energy = torch.sqrt((pred_norm ** 2).sum(dim=-1) + 1e-8)
#         targ_energy = torch.sqrt((targ_norm ** 2).sum(dim=-1) + 1e-8)
        
#         similarity = correlation / (pred_energy * targ_energy + 1e-8)
        
#         # Convert to percentage (0-100%)
#         similarity_pct = (similarity + 1) / 2 * 100  # Map [-1,1] to [0,100]
#         similarity_pct = torch.clamp(similarity_pct, 0, 100)
        
#         bs = pred.shape[0]
#         self.total += similarity_pct.sum().detach().cpu()
#         self.count += bs
    
#     @property
#     def value(self):
#         return (self.total / self.count).item() if self.count > 0 else None
    
#     @property
#     def name(self):
#         return "similarity_%"


In [None]:
# class NoiseReductionScore(Metric):
#     """
#     Measures what percentage of noise was successfully removed.
    
#     - 100% = all noise removed (perfect)
#     - 50% = half the noise removed
#     - 0% = no noise removed
#     - Negative = made it worse
    
#     Most intuitive "accuracy-like" metric!
#     """
#     def reset(self):
#         self.total = 0.0
#         self.count = 0
    
#     def accumulate(self, learn):
#         noisy = learn.xb[0].squeeze(1)
#         clean = learn.yb[0].squeeze(1)
#         pred = learn.pred.squeeze(1)
        
#         # Original noise
#         original_noise = noisy - clean
#         original_noise_power = (original_noise ** 2).sum(dim=-1)
        
#         # Remaining noise after denoising
#         remaining_noise = pred - clean
#         remaining_noise_power = (remaining_noise ** 2).sum(dim=-1)
        
#         # Percentage of noise removed
#         noise_removed = (original_noise_power - remaining_noise_power) / (original_noise_power + 1e-8)
#         noise_removed_pct = noise_removed * 100
        
#         # Clamp to reasonable range
#         noise_removed_pct = torch.clamp(noise_removed_pct, -50, 100)
        
#         bs = pred.shape[0]
#         self.total += noise_removed_pct.sum().detach().cpu()
#         self.count += bs
    
#     @property
#     def value(self):
#         return (self.total / self.count).item() if self.count > 0 else None
    
#     @property
#     def name(self):
#         return "noise_removed_%"

In [None]:
# class DenoisingAccuracy(Metric):
#     """
#     Combines three metrics into one score that feels like classification accuracy:
#     1. SI-SNR Quality (0-100%)
#     2. Noise Reduction (0-100%)
#     3. Correlation (0-100%)
    
#     Expected values:
#     - Loss -6 to -7: accuracy ~60-65%
#     - Loss -10 to -12: accuracy ~75-80%
#     - Loss -15 to -18: accuracy ~85-90%
#     - Loss < -20: accuracy ~90-95%
#     """
#     def reset(self):
#         self.total_quality = 0.0
#         self.total_noise_red = 0.0
#         self.total_corr = 0.0
#         self.count = 0
    
#     def accumulate(self, learn):
#         # Get tensors and move to CPU for calculation
#         noisy = learn.xb[0].squeeze(1).detach().cpu()
#         clean = learn.yb[0].squeeze(1).detach().cpu()
#         pred = learn.pred.squeeze(1).detach().cpu()
        
#         bs = pred.shape[0]
        
#         # Calculate for each item in batch
#         for i in range(bs):
#             noisy_i = noisy[i]
#             clean_i = clean[i]
#             pred_i = pred[i]
            
#             # 1. SI-SNR Quality Score (0-100%)
#             target = clean_i - clean_i.mean()
#             estimate = pred_i - pred_i.mean()
            
#             eps = 1e-8
#             alpha = torch.dot(target, estimate) / (torch.dot(target, target) + eps)
#             target_scaled = alpha * target
#             noise = estimate - target_scaled
            
#             s = torch.sum(target_scaled ** 2)
#             n = torch.sum(noise ** 2)
#             si_snr = 10 * torch.log10((s + eps) / (n + eps))
            
#             # Map SI-SNR to 0-100%
#             # -5 dB = 0%, 0 dB = 20%, 20 dB = 100%
#             quality = torch.clamp((si_snr + 5) / 25 * 100, 0, 100)
            
#             # 2. Noise Reduction Score (0-100%)
#             original_noise = noisy_i - clean_i
#             original_power = torch.mean(original_noise ** 2)
            
#             remaining_noise = pred_i - clean_i
#             remaining_power = torch.mean(remaining_noise ** 2)
            
#             noise_reduction = ((original_power - remaining_power) / (original_power + eps)) * 100
#             noise_reduction = torch.clamp(noise_reduction, 0, 100)
            
#             # 3. Correlation Score (0-100%)
#             pred_norm = pred_i - pred_i.mean()
#             clean_norm = clean_i - clean_i.mean()
            
#             correlation = torch.dot(pred_norm, clean_norm)
#             pred_energy = torch.sqrt(torch.sum(pred_norm ** 2) + eps)
#             clean_energy = torch.sqrt(torch.sum(clean_norm ** 2) + eps)
            
#             corr_coef = correlation / (pred_energy * clean_energy + eps)
#             corr_score = (corr_coef + 1) / 2 * 100  # Map [-1,1] to [0,100]
#             corr_score = torch.clamp(corr_score, 0, 100)
            
#             # Accumulate
#             self.total_quality += quality.item()
#             self.total_noise_red += noise_reduction.item()
#             self.total_corr += corr_score.item()
#             self.count += 1
    
#     @property
#     def value(self):
#         if self.count == 0:
#             return None
        
#         avg_quality = self.total_quality / self.count
#         avg_noise_red = self.total_noise_red / self.count
#         avg_corr = self.total_corr / self.count
        
#         # Weighted combination
#         accuracy = 0.4 * avg_quality + 0.3 * avg_noise_red + 0.3 * avg_corr
        
#         return accuracy
    
#     @property
#     def name(self):
#         return "accuracy_%"

In [None]:
class SISNRMetric(Metric):
    """
    Mean scale-invariant SNR (SI-SNR) in dB, reported as a positive value
    (higher is better).

    Per-item SI-SNR is computed on whatever device the predictions live on.
    Only each batch's scalar sum is copied to the CPU, so the overhead is one
    small device-to-host transfer per batch.

    Rough targets noted for this project (empirical, not enforced here):
      epoch  0: ~7-9 dB
      epoch 20: ~11-13 dB
      epoch 50: ~14-16 dB
      epoch 80: ~17-20 dB
    """
    def reset(self):
        # Running sum of per-item SI-SNR (CPU scalar tensor) and item count.
        self.total = torch.tensor(0.0)
        self.count = 0

    def accumulate(self, learn):
        """Add the SI-SNR of every item in the current batch to the running total."""
        # Assumes pred/targ are (batch, 1, time) or (batch, time) - TODO confirm.
        # Upcast to float32 so the arithmetic never runs in half precision:
        # in float16 the 1e-8 eps rounds to 0 (0/0 -> nan for silent items)
        # and sums of squares over long clips can overflow to inf.
        pred = learn.pred.squeeze(1).detach().float()
        targ = learn.yb[0].squeeze(1).detach().float()

        # Remove the per-item mean so a DC offset does not affect the score.
        targ_zm = targ - targ.mean(dim=-1, keepdim=True)
        pred_zm = pred - pred.mean(dim=-1, keepdim=True)

        eps = 1e-8
        # Projection coefficient of the estimate onto the target.
        alpha = (targ_zm * pred_zm).sum(-1, keepdim=True) / \
                (targ_zm.pow(2).sum(-1, keepdim=True) + eps)
        target_proj = alpha * targ_zm
        s = target_proj.pow(2).sum(-1)
        n = (pred_zm - target_proj).pow(2).sum(-1)
        sisnr = 10 * torch.log10((s + eps) / (n + eps))

        self.total += sisnr.sum().cpu()
        self.count += pred.shape[0]

    @property
    def value(self):
        return (self.total / self.count).item() if self.count > 0 else None

    @property
    def name(self):
        return "sisnr_db"

In [None]:
class NoiseReductionPct(Metric):
    """
    Mean percentage of input noise power removed by the model.

    Per item: 100 * (P_in - P_out) / P_in, clamped to [-100, 100], where
    P_in is the mean power of (noisy - clean) and P_out the mean power of
    (pred - clean). Negative values mean the output is further from the
    clean target than the noisy input was.

    Observed for this project (not enforced here): negative early in
    training, crossing 0 around epoch 15-25.
    """
    def reset(self):
        # Running sum of per-item percentages (CPU scalar tensor) and item count.
        self.total = torch.tensor(0.0)
        self.count = 0

    def accumulate(self, learn):
        """Add the noise-reduction percentage of every item in the batch."""
        # Upcast to float32 so the arithmetic never runs in half precision:
        # in float16 the 1e-8 eps rounds to 0, so items with no input noise
        # would give 0/0 -> nan (and clamp keeps nan).
        noisy = learn.xb[0].squeeze(1).detach().float()
        clean = learn.yb[0].squeeze(1).detach().float()
        pred  = learn.pred.squeeze(1).detach().float()

        original_noise_pwr  = (noisy - clean).pow(2).mean(-1)
        remaining_noise_pwr = (pred  - clean).pow(2).mean(-1)

        # Clamp to [-100, 100] rather than [0, 100] so that outputs worse
        # than the noisy input show up as negative instead of as 0%.
        pct = ((original_noise_pwr - remaining_noise_pwr) /
               (original_noise_pwr + 1e-8) * 100).clamp(-100, 100)

        self.total += pct.sum().cpu()
        self.count += pred.shape[0]

    @property
    def value(self):
        return (self.total / self.count).item() if self.count > 0 else None

    @property
    def name(self):
        return "noise_removed_%"