In [4]:
import torch
import torch.nn as nn

EPS = 1e-13

class SASNRLossSegment(nn.Module):
    def __init__(self, reduction=torch.mean):
        super(SASNRLossSegment, self).__init__()
        self.reduction = reduction

    def forward(self, output, target, out_dict=True):

        # (12)
        denom_s = torch.sum(output*target,dim=-1)
        numer_s = torch.norm(target, dim=-1) * torch.norm(target, dim=-1)
        #scale = (numer_s/denom_s)
        scale = denom_s/numer_s
        s_target = torch.unsqueeze(scale,-1) * target
        print(f"{scale} {numer_s.shape} {denom_s.shape} {s_target.shape}")

        # (13)
        e_noise = output - s_target
        print(f"{e_noise.shape}")
 
        (14)
        # Compute norms
        norm_s_target_squared = torch.sum(s_target ** 2, dim=-1)  
        norm_e_noise_squared = torch.sum(e_noise ** 2, dim=-1)  
        norm_predicted = torch.norm(output, dim=-1)  
        norm_target = torch.norm(target, dim=-1) 

        print(f"{norm_s_target_squared.shape} {norm_e_noise_squared.shape} {norm_predicted.shape} {norm_target.shape}")

        # Scaling factors
        scaling_factor = norm_predicted / norm_target
        min_factor = torch.minimum(scaling_factor, torch.ones_like(scaling_factor))
        max_factor = torch.maximum(scaling_factor, torch.ones_like(scaling_factor))
        scale_term = min_factor / max_factor


        # Compute SA-SNR loss
        ratio = norm_s_target_squared / (norm_e_noise_squared + EPS)  
        loss_per_element = -10 * torch.log10(ratio * scale_term + EPS)  

        loss = self.reduction(loss_per_element)
        return {"SASNRLossSegment": loss} if out_dict else loss

x = torch.rand(2,3,2048)
y = torch.rand(2,3,2048)

m = SASNRLossSegment()


l = m(x,y,out_dict=False)
print(l)

y = x+0.5* torch.rand(2,3,2048)
l = m(x,y,out_dict=False)
print(l)

y = x
l = m(x,y,out_dict=False)
print(l)

tensor([[0.7164, 0.7461, 0.7467],
        [0.7423, 0.7301, 0.7510]]) torch.Size([2, 3]) torch.Size([2, 3]) torch.Size([2, 3, 2048])
torch.Size([2, 3, 2048])
torch.Size([2, 3]) torch.Size([2, 3]) torch.Size([2, 3]) torch.Size([2, 3])
tensor(-0.9634)
tensor([[0.6828, 0.6809, 0.6882],
        [0.6854, 0.6885, 0.6849]]) torch.Size([2, 3]) torch.Size([2, 3]) torch.Size([2, 3, 2048])
torch.Size([2, 3, 2048])
torch.Size([2, 3]) torch.Size([2, 3]) torch.Size([2, 3]) torch.Size([2, 3])
tensor(-10.7479)
tensor([[1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000]]) torch.Size([2, 3]) torch.Size([2, 3]) torch.Size([2, 3, 2048])
torch.Size([2, 3, 2048])
torch.Size([2, 3]) torch.Size([2, 3]) torch.Size([2, 3]) torch.Size([2, 3])
tensor(-140.8473)


In [2]:
class CosSDRLossSegment(nn.Module):
    """
    It's a cosine similarity between predicted and clean signal
        loss = - <y_true, y_pred> / (||y_true|| * ||y_pred||)
    This loss function is always bounded between -1 and 1
    Ref: https://openreview.net/pdf?id=SkeRTsAcYm
    Hyeong-Seok Choi et al., Phase-aware Speech Enhancement with Deep Complex U-Net,
    """
    def __init__(self, reduction=torch.mean):
        super(CosSDRLossSegment, self).__init__()
        self.reduction = reduction

    def forward(self, output, target, out_dict=True):
        num = torch.sum(target * output, dim=-1)
        den = torch.norm(target, dim=-1) * torch.norm(output, dim=-1)
        print(f"{num} {den} {num/den}")
        loss_per_element = -num / (den + EPS)
        loss = self.reduction(loss_per_element)
        return {"CosSDRLossSegment": loss} if out_dict else loss
x = torch.rand(2,3,2048)
y = torch.rand(2,3,2048)

m = CosSDRLossSegment()

z = m(x,y)
print(z)

y = x
z = m(x,y)
print(z)

tensor([[498.7735, 517.9832, 514.9672],
        [518.9438, 517.7992, 507.6986]]) tensor([[665.8411, 684.2076, 679.7012],
        [685.8348, 688.7520, 675.1812]]) tensor([[0.7491, 0.7571, 0.7576],
        [0.7567, 0.7518, 0.7519]])
{'CosSDRLossSegment': tensor(-0.7540)}
tensor([[674.1790, 678.7164, 677.8181],
        [701.0212, 691.0522, 677.8785]]) tensor([[674.1791, 678.7162, 677.8181],
        [701.0211, 691.0522, 677.8785]]) tensor([[1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000]])
{'CosSDRLossSegment': tensor(-1.)}
