In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
def si_snr_loss(estimate, target, eps=1e-8):
    """Return the negative scale-invariant signal-to-noise ratio (SI-SNR) in dB.

    Both inputs are zero-meaned along the last dimension. The centered
    estimate is then projected onto the centered target. The SI-SNR is the
    energy ratio of that projection to the residual, expressed in dB.
    Because the projection absorbs any gain on the estimate, the metric is
    invariant to rescaling the estimate.

    Args:
        estimate: Estimated signal(s); the last dimension is reduced over.
        target: Reference signal(s), elementwise-compatible with ``estimate``.
        eps: Small constant added to the projection denominator and to both
            energies, guarding against division by zero and log of zero.

    Returns:
        A 0-dim tensor holding the negated mean SI-SNR over all leading
        dimensions. Minimizing it maximizes SI-SNR.

    NOTE(review): if a target is constant along the last dimension, its
    centered version is all zeros. The projection then has zero energy, so
    that item contributes roughly ``10*log10(eps / residual)`` dB, which is
    a very large loss. Confirm whether silent targets can occur.
    """
    def _centered(x):
        return x - x.mean(dim=-1, keepdim=True)

    ref = _centered(target)
    est = _centered(estimate)

    # Scale the reference so the residual ``est - projection`` is orthogonal
    # to it; this removes any gain difference between estimate and target.
    dot = torch.sum(ref * est, dim=-1, keepdim=True)
    ref_energy = torch.sum(ref ** 2, dim=-1, keepdim=True)
    projection = (dot / (ref_energy + eps)) * ref
    residual = est - projection

    signal_power = torch.sum(projection ** 2, dim=-1)
    residual_power = torch.sum(residual ** 2, dim=-1)
    ratio_db = 10 * torch.log10((signal_power + eps) / (residual_power + eps))
    return -ratio_db.mean()

In [None]:
class CombinedLoss(nn.Module):
    """Weighted sum of negative SI-SNR and mean absolute (L1) error.

    ``loss = si_snr_loss(pred', targ') + l1_weight * mean(|pred' - targ'|)``

    Here ``pred'`` and ``targ'`` are the inputs with a singleton dimension 1
    removed, when one is present.

    Args:
        l1_weight: Multiplier applied to the L1 term. Defaults to 0.3.
    """

    def __init__(self, l1_weight=0.3):
        super().__init__()
        self.l1_weight = l1_weight

    def forward(self, pred, targ):
        """Return the combined loss as a 0-dim tensor.

        Args:
            pred: Model output. Presumably (batch, 1, time) or (batch, time);
                TODO confirm against callers.
            targ: Reference signal, in a layout matching ``pred`` after
                squeezing dimension 1.
        """
        # Drop a singleton dim 1 (if present) and feed the SAME tensors to
        # both terms. Otherwise a (B, 1, T) prediction against a (B, T)
        # target would silently broadcast to (B, B, T) inside F.l1_loss.
        pred_sq = pred.squeeze(1)
        targ_sq = targ.squeeze(1)
        sisnr = si_snr_loss(pred_sq, targ_sq)
        l1 = F.l1_loss(pred_sq, targ_sq)
        return sisnr + self.l1_weight * l1