In [1]:
import numpy as np
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [24]:
def si_snr(x, s, eps=1e-8):
    """Negated scale-invariant SNR (SI-SNR) in dB, averaged over all signals.

    Both signals are made zero-mean along the last axis, the estimate is
    projected onto the reference, and the ratio between the norm of that
    projection and the norm of the residual is expressed in dB. The
    per-signal values are then averaged and negated, so a better estimate
    yields a smaller (more negative) result.

    Args:
        x: Enhanced (estimated) signal of shape [B, T].
        s: Reference signal of shape [B, T].
        eps: Small constant guarding the two divisions and the logarithm.

    Returns:
        Scalar tensor holding minus the mean SI-SNR over all signals.

    Raises:
        RuntimeError: If ``x`` and ``s`` do not have the same shape.
    """

    def l2norm(mat, keep_dim=False):
        # Euclidean norm along the last (time) axis.
        return torch.norm(mat, dim=-1, keepdim=keep_dim)

    if x.shape != s.shape:
        raise RuntimeError(
            f"Dimension mismatch when calculate si_snr, {x.shape} vs {s.shape}"
        )

    # Remove the DC offset so the measure only depends on the waveforms.
    x_zm = x - torch.mean(x, dim=-1, keepdim=True)
    s_zm = s - torch.mean(s, dim=-1, keepdim=True)

    # Projection of the estimate onto the reference: <x, s> * s / ||s||^2.
    t = (
        torch.sum(x_zm * s_zm, dim=-1, keepdim=True)
        * s_zm
        / (l2norm(s_zm, keep_dim=True) ** 2 + eps)
    )

    ratio = l2norm(t) / (l2norm(x_zm - t) + eps)

    # eps is also added inside the logarithm: a zero projection (e.g. a
    # constant estimate) makes the ratio exactly 0, and log10(0) would turn
    # the whole batch mean into +inf.
    return -torch.mean(20 * torch.log10(ratio + eps))

In [25]:
# Two independent batches of uniform noise: 2 signals of 16000 samples each.
a, b = torch.rand(2, 16000), torch.rand(2, 16000)

In [26]:
# Function form: negated SI-SNR averaged over the batch (a = estimate, b = reference).
si_snr(x=a, s=b)

tensor(44.8756)

In [27]:
class SISNRLoss(torch.nn.Module):
    """Per-signal scale-invariant SNR (SI-SNR) in dB.

    NOTE(review): despite the class name, ``forward`` returns the SI-SNR
    itself (larger means a better estimate). It is neither negated nor
    averaged over the batch; a caller that wants a quantity to minimize has
    to negate and reduce the result.

    Args:
        EPS: Small constant guarding the divisions and the logarithm.
    """

    def __init__(self, EPS=1e-8) -> None:
        super().__init__()
        self.EPS = EPS

    def forward(self, input, target):
        """Compute the SI-SNR of ``input`` with respect to ``target``.

        Args:
            input: Estimated signal; the last axis is reduced over.
            target: Reference signal with the same shape as ``input``.

        Returns:
            Tensor of SI-SNR values in dB with the last axis of the inputs
            reduced away (one value per signal).

        Raises:
            RuntimeError: If ``input`` and ``target`` differ in shape.
        """
        if input.shape != target.shape:
            raise RuntimeError(
                f"Dimension mismatch when calculate si_snr, {input.shape} vs {target.shape}"
            )

        # Remove the DC offset of both signals.
        s_input = input - torch.mean(input, dim=-1, keepdim=True)
        s_target = target - torch.mean(target, dim=-1, keepdim=True)

        # <s, s'> / ||s||**2 * s
        pair_wise_dot = torch.sum(s_target * s_input, dim=-1, keepdim=True)
        s_target_energy = torch.sum(s_target**2, dim=-1, keepdim=True)
        # EPS keeps a silent or constant target (zero energy) from producing
        # 0 / 0 = NaN in the projection.
        pair_wise_proj = pair_wise_dot * s_target / (s_target_energy + self.EPS)

        # Whatever part of the estimate is not explained by the target.
        e_noise = s_input - pair_wise_proj

        pair_wise_sdr = torch.sum(pair_wise_proj**2, dim=-1) / (
            torch.sum(e_noise**2, dim=-1) + self.EPS
        )
        return 10 * torch.log10(pair_wise_sdr + self.EPS)

In [28]:
# Module form: per-signal SI-SNR, averaged over the batch here for comparison.
SISNRLoss()(a, b).mean()

tensor(-44.8727)