In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
def si_snr_loss(estimate, target, eps=1e-8):
    """Negative scale-invariant signal-to-noise ratio, averaged to a scalar.

    Both inputs have their mean along the last dimension removed, the
    estimate is projected onto the target, and the ratio between the energy
    of that projection and the energy of the residual is expressed as
    ``10 * log10(...)``.

    Args:
        estimate: Predicted signal; the last dimension is the one reduced.
        target: Reference signal, combined element-wise with ``estimate``.
        eps: Small constant added to denominators and to both energies
            before the logarithm, to avoid division by zero and ``log(0)``.

    Returns:
        A 0-dim tensor: the mean SI-SNR over all remaining dimensions,
        negated so that a higher SI-SNR gives a lower loss.
    """

    def _remove_mean(signal):
        # Subtract the mean taken over the last dimension.
        return signal - signal.mean(dim=-1, keepdim=True)

    ref = _remove_mean(target)
    est = _remove_mean(estimate)

    # Optimal scaling of the reference: <ref, est> / (||ref||^2 + eps).
    cross = torch.sum(ref * est, dim=-1, keepdim=True)
    ref_energy = torch.sum(ref.pow(2), dim=-1, keepdim=True)
    projection = (cross / (ref_energy + eps)) * ref

    # Whatever the scaled reference does not explain counts as noise.
    residual = est - projection

    signal_power = projection.pow(2).sum(dim=-1)
    noise_power = residual.pow(2).sum(dim=-1)

    ratio = (signal_power + eps) / (noise_power + eps)
    return -(10 * torch.log10(ratio)).mean()

In [None]:
class SpectralLoss(nn.Module):
    """STFT-magnitude loss: L1 on magnitudes plus L1 on log-magnitudes.

    Args:
        n_fft: FFT size passed to ``torch.stft``; also the Hann window length.
        hop: Hop length passed to ``torch.stft``.
    """

    def __init__(self, n_fft=512, hop=128):
        super().__init__()
        self.n_fft = n_fft
        self.hop = hop

    def _stft_magnitude(self, wave, window):
        """Return the floored STFT magnitude of ``wave``, computed on CPU."""
        # `.cpu()` is a differentiable copy (nothing is detached), so
        # gradients still reach the caller's tensor.
        spectrum = torch.stft(
            wave.cpu(),
            self.n_fft,
            hop_length=self.hop,
            window=window,
            return_complex=True,
        )
        # Split into (real, imag) and form |X| by hand; the power is floored
        # at 1e-8 before the square root.
        real_imag = torch.view_as_real(spectrum)
        power = real_imag.pow(2).sum(-1)
        return power.clamp(min=1e-8).sqrt()

    def forward(self, pred, targ):
        """Compute the spectral loss between ``pred`` and ``targ``.

        Both inputs have dimension 1 squeezed (presumably a singleton channel
        axis of a (batch, 1, time) waveform -- TODO confirm), are cast to
        float32 and clamped to [-1, 1] before the STFT.

        Returns:
            Scalar tensor ``mag_l1 + log_mag_l1`` placed on the device of the
            input ``pred``.
        """
        est_wave = torch.clamp(pred.squeeze(1).float(), min=-1.0, max=1.0)
        ref_wave = torch.clamp(targ.squeeze(1).float(), min=-1.0, max=1.0)

        # A fresh Hann window is created on every call; it lives on the CPU,
        # matching the CPU copies handed to the STFT.
        hann = torch.hann_window(self.n_fft)

        est_mag = self._stft_magnitude(est_wave, hann)
        ref_mag = self._stft_magnitude(ref_wave, hann)

        linear_term = F.l1_loss(est_mag, ref_mag)
        log_term = F.l1_loss(
            torch.log(est_mag + 1e-8),
            torch.log(ref_mag + 1e-8),
        )

        # The terms were computed on CPU; hand the sum back on the device the
        # (squeezed, clamped) prediction lives on.
        total = linear_term + log_term
        return total.to(est_wave.device)

In [None]:
class CombinedLoss(nn.Module):
    """Weighted sum of SI-SNR loss, waveform L1 and (optionally) spectral loss.

    The total is ``sisnr + l1_weight * l1`` and, when ``use_spectral`` is
    true, additionally ``+ spectral_weight * spec``.

    Args:
        debug: When true, print the individual loss terms on every call.
        use_spectral: When true, add the ``SpectralLoss`` term (off by
            default).
        l1_weight: Multiplier applied to the waveform L1 term. Defaults to
            0.3, the value that was previously hard-coded.
        spectral_weight: Multiplier applied to the spectral term. Defaults to
            0.1, the value that was previously hard-coded.
    """

    def __init__(self, debug=False, use_spectral=False,
                 l1_weight=0.3, spectral_weight=0.1):
        super().__init__()
        # Always constructed so that `self.spectral` exists regardless of
        # `use_spectral`; it is only called when the flag is set.
        self.spectral = SpectralLoss()
        self.debug = debug
        self.use_spectral = use_spectral
        self.l1_weight = l1_weight
        self.spectral_weight = spectral_weight

    def forward(self, pred, targ):
        """Return the combined scalar loss for ``pred`` against ``targ``.

        Dimension 1 is squeezed before the SI-SNR term (presumably a
        singleton channel axis -- TODO confirm); the L1 term uses the tensors
        as given.
        """
        sisnr = si_snr_loss(pred.squeeze(1), targ.squeeze(1))
        # NOTE(review): assumes pred and targ share a shape; F.l1_loss would
        # broadcast mismatched shapes instead of failing -- confirm at caller.
        l1 = F.l1_loss(pred, targ)

        if self.debug:
            # `.item()` forces a host sync; acceptable because this is
            # debug-only output.
            print(f"  sisnr={sisnr.item():.4f}  l1={l1.item():.4f}")

        total = sisnr + self.l1_weight * l1

        if self.use_spectral:
            spec = self.spectral(pred, targ)
            if self.debug:
                print(f"  spec={spec.item():.4f}")
            total = total + self.spectral_weight * spec

        return total