In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
def si_snr_loss(estimate, target, eps=1e-8):
    """Negative scale-invariant signal-to-noise ratio (in dB), averaged over the batch.

    Both signals are made zero-mean along the last axis. The estimate is then
    split into its projection onto the target and the orthogonal residual.
    SI-SNR = 10 * log10(||projection||^2 / ||residual||^2).

    Args:
        estimate: predicted signal; the last axis is treated as time.
        target: reference signal, broadcast-compatible with ``estimate``.
        eps: small constant guarding the divisions against zero energy.

    Returns:
        Scalar tensor ``-mean(SI-SNR)``, so that minimising it maximises SI-SNR.
    """
    def _zero_mean(x):
        return x - x.mean(dim=-1, keepdim=True)

    def _energy(x, keepdim=False):
        return x.pow(2).sum(dim=-1, keepdim=keepdim)

    ref_c = _zero_mean(target)
    est_c = _zero_mean(estimate)

    # Optimal scaling of the reference so it best explains the estimate.
    scale = (ref_c * est_c).sum(dim=-1, keepdim=True) / (_energy(ref_c, keepdim=True) + eps)
    projection = scale * ref_c
    residual = est_c - projection

    ratio_db = 10 * torch.log10((_energy(projection) + eps) / (_energy(residual) + eps))
    return ratio_db.mean().neg()


In [None]:
class MultiResolutionSpectralLoss(nn.Module):
    """Multi-resolution STFT loss.

    For each (n_fft, hop) resolution the loss adds:
      * L1 distance between the linear STFT magnitudes, and
      * L1 distance between the log STFT magnitudes.
    The sum over resolutions is divided by the number of resolutions.

    The STFTs are computed on CPU: the input is moved there with ``.cpu()`` and
    the result is moved back to the input's device. Both transfers are
    differentiable, so gradients still reach ``pred``.

    Args:
        resolutions: iterable of ``(n_fft, hop)`` pairs. ``None`` (the default)
            uses ``[(256, 64), (512, 128), (1024, 256)]``.

    Raises:
        ValueError: if ``resolutions`` is empty, because the average over
            resolutions would otherwise divide by zero.
    """

    DEFAULT_RESOLUTIONS = ((256, 64), (512, 128), (1024, 256))

    def __init__(self, resolutions=None):
        super().__init__()
        if resolutions is None:
            resolutions = self.DEFAULT_RESOLUTIONS
        self.resolutions = [tuple(r) for r in resolutions]
        if not self.resolutions:
            raise ValueError("resolutions must contain at least one (n_fft, hop) pair")
        # Hann windows are built once per n_fft instead of on every STFT call.
        # They are deliberately kept as a plain dict rather than registered
        # buffers: module.to(device) must not move them off the CPU, where the
        # STFT runs.
        self._windows = {n_fft: torch.hann_window(n_fft) for n_fft, _ in self.resolutions}

    def _stft_mag(self, x, n_fft, hop):
        """Return the STFT magnitude of ``x``, computed on CPU.

        The squared magnitude is clamped at 1e-8 before the sqrt. This keeps the
        sqrt gradient finite for silent bins, so magnitudes are floored at 1e-4.
        """
        window = self._windows.get(n_fft)
        if window is None:
            # n_fft not configured at construction (e.g. resolutions edited later)
            window = torch.hann_window(n_fft)
        stft = torch.view_as_real(
            torch.stft(x.cpu(), n_fft, hop, window=window, return_complex=True)
        )
        return stft.pow(2).sum(-1).clamp(min=1e-8).sqrt()

    def forward(self, pred, targ):
        """Compute the spectral loss between ``pred`` and ``targ``.

        Args:
            pred, targ: waveforms. A singleton dim 1 is squeezed; presumably the
                shape is (batch, 1, time). TODO confirm against callers.

        Returns:
            Scalar tensor on ``pred``'s device.
        """
        # The clamp to [-1, 1] means samples outside that range receive zero
        # gradient from this term.
        pred = pred.squeeze(1).float().clamp(-1, 1)
        targ = targ.squeeze(1).float().clamp(-1, 1)

        terms = []
        for n_fft, hop in self.resolutions:
            pred_mag = self._stft_mag(pred, n_fft, hop)
            targ_mag = self._stft_mag(targ, n_fft, hop)
            terms.append(F.l1_loss(pred_mag, targ_mag))
            terms.append(F.l1_loss(
                torch.log(pred_mag + 1e-8),
                torch.log(targ_mag + 1e-8)
            ))
        total = torch.stack(terms).sum()
        return (total / len(self.resolutions)).to(pred.device)

In [None]:
class CombinedLoss(nn.Module):
    """Training loss combining negative SI-SNR, L1, a power-ratio regulariser and,
    optionally, a multi-resolution spectral term.

        loss = sisnr + l1_weight * l1 [+ spec_weight * spec] + power_weight * power_reg

    ``power_reg`` is the batch mean of ``|rms(pred) / rms(targ) - 1|``. It
    penalises predictions whose loudness differs from the target's.

    Args:
        debug: if True, print every term on each forward call.
        use_spectral: if True, add the MultiResolutionSpectralLoss term.
        l1_weight, spec_weight, power_weight: term weights. The defaults
            (0.3, 0.1, 5) reproduce the original hard-coded values.
    """

    def __init__(self, debug=False, use_spectral=False,
                 l1_weight=0.3, spec_weight=0.1, power_weight=5):
        super().__init__()
        self.spectral = MultiResolutionSpectralLoss()
        self.debug = debug
        self.use_spectral = use_spectral
        self.l1_weight = l1_weight
        self.spec_weight = spec_weight
        self.power_weight = power_weight

    @staticmethod
    def _power_reg(pred, targ):
        """Return ``(power_reg, ratio)``.

        ``ratio`` is the per-example ``rms(pred) / rms(targ)``. ``power_reg`` is
        the mean of ``|ratio - 1|``, moved to ``pred``'s device. The result is
        differentiable with respect to ``pred``.
        """
        # Squeeze and compute in CPU float32 so the values mirror the diagnostic
        # and metrics computation. pred is intentionally NOT detached:
        # `.cpu()` / `.to()` are differentiable, so this regulariser contributes
        # gradients instead of just a constant offset to the loss.
        pred_cpu = pred.squeeze(1).float().cpu()
        # The target is a fixed reference, so no gradient is needed through it.
        targ_cpu = targ.squeeze(1).detach().float().cpu()

        # Clamp the mean square before sqrt, as _stft_mag does. sqrt has an
        # infinite derivative at 0, which would give NaN gradients for a silent
        # prediction.
        pred_rms = pred_cpu.pow(2).mean(dim=-1).clamp(min=1e-8).sqrt()
        targ_rms = targ_cpu.pow(2).mean(dim=-1).sqrt()
        ratio = pred_rms / (targ_rms + 1e-8)
        power_reg = (ratio - 1.0).abs().mean().to(pred.device)  # back to pred.device
        return power_reg, ratio

    def forward(self, pred, targ):
        """Compute the combined loss.

        Args:
            pred, targ: waveforms with a singleton dim 1, which is squeezed
                (presumably (batch, 1, time); TODO confirm against callers).

        Returns:
            Scalar loss tensor.
        """
        sisnr = si_snr_loss(pred.squeeze(1), targ.squeeze(1))
        l1 = F.l1_loss(pred, targ)
        power_reg, ratio = self._power_reg(pred, targ)
        spec = self.spectral(pred, targ) if self.use_spectral else None

        if self.debug:
            msg = f"  sisnr={sisnr.item():.4f}  l1={l1.item():.4f}  "
            if spec is not None:
                msg += f"spec={spec.item():.4f}  "
            msg += (f"power_reg={power_reg.item():.6f}  "
                    f"ratio_mean={ratio.mean().item():.4f}")
            print(msg)

        total = sisnr + self.l1_weight * l1
        if spec is not None:
            total = total + self.spec_weight * spec
        return total + self.power_weight * power_reg
