In [2]:
import torch

In [3]:
def calc_sdr_torch(estimation, origin, mask=None, eps=1e-8):
    """
    Batch-wise SDR calculation on PyTorch tensors.

    Each estimate is projected onto its reference signal; the SDR is the
    power of that projection divided by the power of what is left over,
    expressed in dB.

    Args:
        estimation: (batch, nsample) estimated signals.
        origin: (batch, nsample) reference signals.
        mask: optional, (batch, nsample), binary. When given, both inputs are
            multiplied by it first so that only unmasked samples contribute.
        eps: small positive constant that keeps the division and both
            logarithms finite.

    Returns:
        Tensor of shape (batch,) with the SDR of each row in dB.
    """
    if mask is not None:
        origin = origin * mask
        estimation = estimation * mask

    origin_power = torch.pow(origin, 2).sum(1, keepdim=True) + eps  # (batch, 1)

    # Least-squares gain that best maps the reference onto the estimate.
    scale = torch.sum(origin * estimation, 1, keepdim=True) / origin_power  # (batch, 1)

    est_true = scale * origin  # (batch, nsample)
    est_res = estimation - est_true  # (batch, nsample)

    # Clamp instead of adding eps: ordinary inputs keep their exact value,
    # while a perfect estimate (zero residual) or a silent estimate (zero
    # target) no longer sends log10 to -inf.
    true_power = torch.clamp(torch.pow(est_true, 2).sum(1), min=eps)  # (batch,)
    res_power = torch.clamp(torch.pow(est_res, 2).sum(1), min=eps)  # (batch,)

    return 10 * torch.log10(true_power) - 10 * torch.log10(res_power)  # (batch,)

In [4]:
# Two random test signals of shape (1, 64000) for a quick SDR smoke test.
# They are drawn from a dedicated, seeded generator so the values are
# repeatable on every Restart & Run All without touching torch's global RNG.
_rng = torch.Generator().manual_seed(0)
t1 = torch.rand((1, 64000), generator=_rng)
t2 = torch.rand((1, 64000), generator=_rng)

In [5]:
# Smoke test: SDR of one random tensor (t1) measured against another (t2).
calc_sdr_torch(t1, t2)

torch.Size([1, 1])
torch.Size([1, 1])


tensor([1.1156])

In [6]:
def si_snr(y_hat, y):
    """
    Pairwise SI-SNR between every row of `y_hat` and every row of `y`.

    For each (estimate i, reference j) pair the estimate is projected onto
    the reference:

        s_target = <y_hat_i, y_j> / ||y_j||^2 * y_j
        e_noise  = y_hat_i - s_target
        SI-SNR   = 10 * log10(||s_target||^2 / ||e_noise||^2)

    Measuring the noise against the projected target (rather than against
    the raw reference) is what makes the value independent of the scale of
    `y_hat`. No mean removal is applied to the inputs.

    Args:
        y_hat: (n_est, nsample) estimated signals.
        y: (n_ref, nsample) reference signals.

    Returns:
        0-dim tensor: the pairwise values are summed over the reference
        axis for each estimate and the largest of those sums is returned.
        NOTE(review): this reduction is carried over unchanged from the
        earlier draft; it is not a one-to-one assignment between estimates
        and references — confirm it is what is wanted.
    """
    eps = 1e-8

    y_power = torch.pow(y, 2).sum(-1) + eps  # (n_ref,)

    # proj[i, j] = <y_hat[i], y[j]> / ||y[j]||^2; the division broadcasts
    # over the last axis, so column j is divided by y_power[j].
    proj = y_hat @ y.t() / y_power  # (n_est, n_ref)

    # Keep every pairwise tensor in the same [estimate, reference] layout so
    # that target and noise energies line up element for element.
    s_target = proj.unsqueeze(-1) * y.unsqueeze(0)  # (n_est, n_ref, nsample)
    e_noise = y_hat.unsqueeze(1) - s_target  # (n_est, n_ref, nsample)

    # Clamp so that an exact match or an orthogonal pair gives a finite value
    # instead of +/-inf.
    target_power = torch.clamp(torch.pow(s_target, 2).sum(-1), min=eps)  # (n_est, n_ref)
    noise_power = torch.clamp(torch.pow(e_noise, 2).sum(-1), min=eps)  # (n_est, n_ref)

    pairwise = 10 * (torch.log10(target_power) - torch.log10(noise_power))

    return torch.max(pairwise.sum(-1))

In [3]:
# Fixed (2, 4) inputs for the step-by-step checks in the cells that follow.
# Several of those cells hard-code numbers derived from these exact values
# (e.g. 1.082, 0.3390, 0.8162), so the tensors are pinned to the recorded
# draw instead of being re-sampled with torch.rand on every run.
y = torch.tensor([[0.0234, 0.8795, 0.9800, 0.1565],
                  [0.2720, 0.2555, 0.4848, 0.8412]])
y_hat = torch.tensor([[0.8396, 0.1334, 0.3134, 0.9723],
                      [0.0495, 0.5926, 0.2339, 0.2166]])

In [4]:
# Display both input tensors so the hand checks below can be followed.
y, y_hat

(tensor([[0.0234, 0.8795, 0.9800, 0.1565],
         [0.2720, 0.2555, 0.4848, 0.8412]]),
 tensor([[0.8396, 0.1334, 0.3134, 0.9723],
         [0.0495, 0.5926, 0.2339, 0.2166]]))

In [67]:
# Energy of each row of y: square every sample, then add along the last axis.
y_power = y.pow(2).sum(dim=-1)
y_power

tensor([1.7590, 1.0820])

In [66]:
# All pairwise dot products: entry [i, j] is <y_hat[i], y[j]>.
torch.matmul(y_hat, y.t())

tensor([[0.5963, 1.2324],
        [0.7854, 0.4605]])

In [70]:
# Broadcasting check: divide every dot product by the energy of y[1] alone.
# Only column 1 should agree with the broadcast division by the full
# y_power vector. The divisor is read from y_power rather than typed in as
# a rounded literal, so the cell stays correct if the inputs change.
y_hat@y.t()/y_power[1]

tensor([[0.5511, 1.1390],
        [0.7259, 0.4256]])

In [73]:
# Projection coefficients. The division broadcasts over the last axis, so
# column j of the dot-product matrix is divided by y_power[j]:
# temp[i, j] = <y_hat[i], y[j]> / ||y[j]||^2.
temp = torch.matmul(y_hat, y.t()) / y_power
temp

tensor([[0.3390, 1.1390],
        [0.4465, 0.4256]])

In [85]:
# Display the coefficients next to y before combining them.
temp, y

(tensor([[0.3390, 1.1390],
         [0.4465, 0.4256]]),
 tensor([[0.0234, 0.8795, 0.9800, 0.1565],
         [0.2720, 0.2555, 0.4848, 0.8412]]))

In [117]:
# Scale every row of y by each of its coefficients. After the transpose the
# leading axis indexes rows of y, so s_target[j, i] = temp[i, j] * y[j].
s_target = temp.t()[:, :, None] * y[:, None, :]
s_target

tensor([[[0.0079, 0.2982, 0.3322, 0.0530],
         [0.0105, 0.3927, 0.4376, 0.0699]],

        [[0.3098, 0.2910, 0.5522, 0.9582],
         [0.1158, 0.1087, 0.2063, 0.3580]]])

In [127]:
# log10 of the energy of every scaled-target vector in s_target.
tempp = s_target.pow(2).sum(dim=-1).log10()

In [125]:
# Spot check: energy of the single vector s_target[0, 1].
s_target[0, 1].pow(2).sum(-1)

tensor(0.3507)

In [96]:
# Hand check of s_target[0]: y[0] scaled by the two coefficients that belong
# to it, temp[0, 0] and temp[1, 0]. The coefficients are read from temp
# instead of being retyped as rounded literals.
y[0] * temp[0, 0], y[0] * temp[1, 0]

(tensor([0.0079, 0.2982, 0.3322, 0.0530]),
 tensor([0.0105, 0.3927, 0.4376, 0.0699]))

In [100]:
# Display both inputs again before forming their pairwise differences.
y_hat, y

(tensor([[0.8396, 0.1334, 0.3134, 0.9723],
         [0.0495, 0.5926, 0.2339, 0.2166]]),
 tensor([[0.0234, 0.8795, 0.9800, 0.1565],
         [0.2720, 0.2555, 0.4848, 0.8412]]))

In [102]:
# Pairwise differences via broadcasting: entry [i, j] is y_hat[i] - y[j].
y_hat[:, None, :] - y

tensor([[[ 0.8162, -0.7461, -0.6666,  0.8158],
         [ 0.5676, -0.1221, -0.1714,  0.1311]],

        [[ 0.0261, -0.2870, -0.7461,  0.0601],
         [-0.2225,  0.3371, -0.2510, -0.6246]]])

In [128]:
# log10 of the energy of every pairwise difference y_hat[i] - y[j].
tempp2 = (y_hat[:, None, :] - y).pow(2).sum(dim=-1).log10()

In [116]:
# Spot check: energy of the difference between y_hat[1] and y[0].
(y_hat[1] - y[0]).pow(2).sum(-1)

tensor(0.6433)

In [111]:
# Spot check: the difference vector for the pair (y_hat[0], y[1]).
y_hat[0] - y[1]

tensor([ 0.5676, -0.1221, -0.1714,  0.1311])

In [110]:
# Hand arithmetic: square of 0.8162, the first entry of the pairwise
# difference tensor printed earlier.
0.8162**2

0.66618244

In [113]:
# Hand arithmetic on a second difference entry.
# NOTE(review): 0.5657 does not appear in the difference tensor; the printed
# entry is 0.5676, so this looks like transposed digits — confirm.
0.5657**2

0.32001649

In [131]:
# Largest row-sum of 10 * (temp - tempp2).
# NOTE(review): this subtracts tempp2 from `temp` (the projection
# coefficients), not from `tempp` (the log10 target energies computed
# earlier) — possibly a typo; confirm which was intended before relying on
# the value.
torch.max((10*(temp - tempp2)).sum(-1))

tensor(15.2612)

In [132]:
# Row-sums of 10 * (temp - tempp2), i.e. the values the max is taken over.
# NOTE(review): uses `temp` (projection coefficients) where `tempp` (log10
# target energies) may have been intended — confirm.
(10*(temp - tempp2)).sum(-1)

tensor([15.2612, 12.7388])