In [1]:
from typing import Optional

import torch

  from .autonotebook import tqdm as notebook_tqdm


In [16]:
def Loss_mag_wav(
    output: torch.Tensor,
    target: torch.Tensor,
    n_fft: int = 512,
    hop_length: Optional[int] = None,
    window: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """L1 magnitude loss + L1 waveform loss for STFT-domain estimates.

    Reference:
        Wang, Zhong-Qiu, et al. "STFT-Domain Neural Speech Enhancement with
        Very Low Algorithmic Latency." arXiv preprint arXiv:2204.09911 (2022).

    Args:
        output: estimated spectrogram, shape [B, 2(real, imag), F, T].
        target: reference spectrogram, same shape as ``output``.
        n_fft: FFT size passed to ``torch.istft``. F must equal
            ``n_fft // 2 + 1`` (default 512 -> F = 257).
        hop_length: hop size passed to ``torch.istft``. None uses the torch
            default (``n_fft // 4``).
        window: synthesis window passed to ``torch.istft``. None uses the
            torch default (all-ones window). It should match the analysis
            window that produced the spectrograms.

    Returns:
        Scalar tensor equal to
        ``(sum |(|output| - |target|)| + sum |istft(output) - istft(target)|) / (B * T)``.

    Note:
        The normalisation divides summed absolute errors by B * T (batch size
        times number of STFT frames). It is not a per-element mean, and the
        waveform term's relative weight depends on ``hop_length``. This scaling
        is kept unchanged for backward compatibility.
    """
    # Build each complex spectrogram once and reuse it for both loss terms.
    output_c = torch.complex(output[:, 0], output[:, 1])
    target_c = torch.complex(target[:, 0], target[:, 1])

    # torch.abs on a complex tensor equals sqrt(re^2 + im^2). Its gradient at an
    # exactly-zero bin is 0, whereas the explicit sqrt form yields NaN there
    # (infinite sqrt derivative multiplied by 0).
    l1_mag = (output_c.abs() - target_c.abs()).abs().sum()

    wav_output = torch.istft(output_c, n_fft=n_fft, hop_length=hop_length, window=window)
    wav_target = torch.istft(target_c, n_fft=n_fft, hop_length=hop_length, window=window)
    l1_wav = (wav_output - wav_target).abs().sum()

    # Summed errors divided by batch size * number of frames (see Note above).
    return (l1_mag + l1_wav) / (output.shape[0] * output.shape[-1])

In [17]:
# Sanity check on random spectrograms: the loss should shrink roughly in
# proportion to the perturbation size and be exactly zero when output == target.
torch.manual_seed(0)  # fixed seed so the printed values are reproducible on re-run

SPEC_SHAPE = (3, 2, 257, 4)  # [B, 2(real, imag), F = 512 // 2 + 1, T frames]
NOISE_SCALES = (0.1, 0.01, 0.001)

x = torch.rand(SPEC_SHAPE)
y1, y2, y3 = (x + torch.rand(SPEC_SHAPE) * scale for scale in NOISE_SCALES)

for scale, y in zip(NOISE_SCALES, (y1, y2, y3)):
    print(f"noise scale {scale}: {Loss_mag_wav(x, y)}")
print(f"identical inputs: {Loss_mag_wav(x, x)}")

assert Loss_mag_wav(x, x).item() == 0.0, "loss must vanish when output == target"

tensor(17.0470)
tensor(1.6636)
tensor(0.1695)
tensor(0.)
