In [7]:
import os,glob
import torch
import matplotlib
import matplotlib.pylab as plt
import numpy as np
import librosa
import soundfile as sf
from matplotlib import cm
import torch

In [8]:
root_sample = '/home/data/kbh/Study_Loss/'

## Aux Funcs

In [9]:
def spectrogram(x) : 
    eps = 1e-13

    # dB scale
    tmp = np.sqrt(x[:,:,0]**2 + x[:,:,0]**2)
    tmp = np.clip(tmp,a_min=eps,a_max=None)
    tmp = 10*np.log10(tmp)
    tmp = tmp[::-1,:]

    im = plt.imshow(tmp, cmap=cm.jet, aspect='auto')
    plt.colorbar(im)
    plt.clim(-80,20)
    plt.show()

In [41]:
window = torch.hann_window(window_length=1024, periodic=True,
                               dtype=None, layout=torch.strided, device=None,
                               requires_grad=False)

def Load(root,category):
    n_fft = 1024
    sr = 16000
    length = 16*15

    batch_wav = None
    batch_spec = None
    
    list_path = [x for x in glob.glob(os.path.join(root,category,'*.pt'))]
    list_name = [x.split('/')[-1] for x in list_path]
    list_id   = [x.split('.')[0] for x in list_name]
    
    for id in list_id :
        tmp_spec = torch.load(root + '/' + category +'/'+id + '.pt')
        tmp_spec = torch.squeeze(tmp_spec)
        tmp_spec = tmp_spec[:,:length,:]
        tmp_wav = torch.istft(tmp_spec,n_fft=1024, hop_length = 256, 
                  window=window, center = True, normalized=False,onesided=True,
                  length=length*256)    
            
        tmp_spec = torch.unsqueeze(tmp_spec,0)
        tmp_wav  = torch.unsqueeze(tmp_wav,0)
            
        #print('------')
        #print(tmp_spec.shape)
        #print(tmp_wav.shape)   
            
        if batch_spec is None :
            batch_spec = tmp_spec
            batch_wav  = tmp_wav
        else :
            batch_spec = torch.cat((batch_spec,tmp_spec),0)
            batch_wav = torch.cat((batch_wav,tmp_wav),0)
        #print(batch_spec.shape)
        #print(batch_wav.shape)
            
    return batch_wav, batch_spec

In [42]:
clean_wav, clean_spec = Load(root_sample,'clean')
estim_wav, estim_spec = Load(root_sample,'estim')
output_wav,output_spec = Load(root_sample,'DCUNET_t28')
print(clean_wav.shape)
print(clean_spec.shape)
print(estim_wav.shape)
print(estim_spec.shape)
print(output_wav.shape)
print(output_spec.shape)

torch.Size([10, 61440])
torch.Size([10, 513, 240, 2])
torch.Size([10, 61440])
torch.Size([10, 513, 240, 2])
torch.Size([10, 61440])
torch.Size([10, 513, 240, 2])


# Scale inspect

In [95]:
print(estim_spec.shape)

torch.Size([10, 513, 240, 2])


## Loss

In [192]:
# wav domain
def SDRLoss(output,target):
    xy = torch.diag(output @ target.t())
    yy = torch.diag(target @ target.t())
    xx = torch.diag(output @ output.t())

    SDR = xy**2/ (yy*xx - xy**2 )
    return torch.mean(SDR) 

def iSDRLoss(output,target):
    sdr = SDRLoss(output,target)
    return 1/sdr

MSE = torch.nn.MSELoss() 
L1  = MSE = torch.nn.L1Loss() 

## Weight Error 
def t1(output,target,alpha = 0.5,is_wav=False):
    
    if is_wav : 
        #s_abs     = target.abs()
        #s_hat_abs = output.abs()
        d = target - output
    else : 
        #print(target.shape)
        s_abs = torch.sqrt(target[:,:,:,0]**2 + target[:,:,:,1]**2)
        s_hat_abs = torch.sqrt(output[:,:,:,0]**2 + output[:,:,:,1]**2)
        
        s_abs = s_abs/torch.mean(s_abs)
        s_hat_abs = s_hat_abs/torch.mean(s_hat_abs)
        
        d = s_abs - s_hat_abs
    return torch.mean(alpha *(d + d.abs())/2 + (1-alpha) * (d-d.abs())/2)

mel_basis = librosa.filters.mel(sr=16000, n_fft=1024,n_mels=40)
mel_basis = torch.from_numpy(mel_basis)

## Meldomain Weighted Error
def t2(output,target,alpha=0.7):
    s_mag = torch.sqrt(target[:,:,:,0]**2 + target[:,:,:,1]**2)
    s_hat_mag = torch.sqrt(output[:,:,:,0]**2 + output[:,:,:,1]**2)
    
    #s_mag = s_mag/torch.mean(s_mag,dim=(1,2)).view(10, 1,1)
    #s_hat_mag = s_hat_mag/torch.mean(s_mag,dim=(1,2)).view(10, 1,1)
    
    s_mag = s_mag/torch.mean(s_mag)
    s_hat_mag = s_hat_mag/torch.mean(s_mag)
    
    
   # print(mel_basis.shape)
    #print(s_mag.shape)
    s = torch.matmul(mel_basis,s_mag)
    #print(s.shape)
    s_hat = torch.matmul(mel_basis,s_hat_mag)
    d = s - s_hat
    return torch.mean(alpha *(d + d.abs())/2 + (1-alpha) * ((d-d.abs())).abs()/2)

## Weight on Scale Error
def t3(output,target):
    pass

In [193]:
print(loss + ', spec , estim  , 0.9, ' + str(t2(estim_spec,clean_spec,alpha=0.9)))
print(loss + ', spec , output , 0.9, ' + str(t2(output_spec,clean_spec,alpha=0.9))) 

LSD, spec , estim  , 0.9, tensor(0.0494)
LSD, spec , output , 0.9, tensor(0.0851)


In [191]:
print(loss + ', spec , estim  , 0.9, ' + str(t2(estim_spec,clean_spec,alpha=0.9)))
print(loss + ', spec , output , 0.9, ' + str(t2(output_spec,clean_spec,alpha=0.9))) 

LSD, spec , estim  , 0.9, tensor(0.0488)
LSD, spec , output , 0.9, tensor(0.0852)


### On Optimal Frequency-Domain Multichannel Linear Filtering for Noise Reduction  
https://ieeexplore.ieee.org/abstract/document/5089420  



In [158]:
# Resudal Signal Distortion
def RSD(output,target) :
    target - output

# Local Signal Distortion Index
def LSD(output,target) : 
    d = target-output
    tmp = d*torch.conj(d)/target*torch.conj(target)
    return torch.mean(tmp.abs()[:,:,:,0])

# Local Noise Reduction factor
def LNO():
    pass

In [145]:
ttt = LSD(estim_spec,clean_spec)
print(ttt.shape)
print(ttt)

torch.Size([])
tensor(3.0670)


In [126]:
x = torch.rand(1,dtype=torch.cfloat)
print(x)

tensor([0.9994+0.4754j])


In [128]:
x = x * torch.conj(x)
print(x)

tensor([1.2249+1.2707e-08j])


In [130]:
x.abs()

tensor([1.2249])

# Get Loss  

## IVA estim VS IVA output

# CSV

In [172]:
list_run = ['iSDR', 'MSE ', 't1  ','t2','LSD']

for loss in list_run : 
    if loss == 'iSDR' :
        print(loss + ', wav  , estim  , ' + str(iSDRLoss(estim_wav,clean_wav)))
        print(loss + ', wav  , output , ' + str(iSDRLoss(output_wav,clean_wav)))  
    elif loss == 'MSE ':
        print(loss + ', wav  , estim  , ' + str(MSE(estim_wav,clean_wav)))
        print(loss + ', wav  , output , ' + str(MSE(output_wav,clean_wav)))  
        print(loss + ', spec , estim  , ' + str(MSE(estim_spec,clean_spec)))
        print(loss + ', spec , output , ' + str(MSE(output_spec,clean_spec)))  
    elif loss == 't1  ' :
        print(loss + ', wav  , estim  , 0.9, ' + str(t1(estim_wav,clean_wav,alpha=0.9,is_wav=True)))
        print(loss + ', wav  , output , 0.9, ' + str(t1(output_wav,clean_wav,alpha=0.9,is_wav=True)))  
        print(loss + ', spec , estim  , 0.9, ' + str(t1(estim_spec,clean_spec,alpha=0.9)))
        print(loss + ', spec , output , 0.9, ' + str(t1(output_spec,clean_spec,alpha=0.9))) 
        
        print(loss + ', wav  , estim  , 0.5, ' + str(t1(estim_wav,clean_wav,alpha=0.5,is_wav=True)))
        print(loss + ', wav  , output , 0.5, ' + str(t1(output_wav,clean_wav,alpha=0.5,is_wav=True)))  
        print(loss + ', spec , estim  , 0.5, ' + str(t1(estim_spec,clean_spec,alpha=0.5)))
        print(loss + ', spec , output , 0.5, ' + str(t1(output_spec,clean_spec,alpha=0.5)))  
        
        print(loss + ', wav  , estim  , 0.1, ' + str(t1(estim_wav,clean_wav,alpha=0.1,is_wav=True)))
        print(loss + ', wav  , output , 0.1, ' + str(t1(output_wav,clean_wav,alpha=0.1,is_wav=True)))  
        print(loss + ', spec , estim  , 0.1, ' + str(t1(estim_spec,clean_spec,alpha=0.1)))
        print(loss + ', spec , output , 0.1, ' + str(t1(output_spec,clean_spec,alpha=0.1)))  
    elif loss == 't2' :
        print(loss + ', spec , estim  , 1.0, ' + str(t2(estim_spec,clean_spec,alpha=1.0)))
        print(loss + ', spec , output , 1.0, ' + str(t2(output_spec,clean_spec,alpha=1.0)))
        
        print(loss + ', spec , estim  , 0.99, ' + str(t2(estim_spec,clean_spec,alpha=0.99)))
        print(loss + ', spec , output , 0.99, ' + str(t2(output_spec,clean_spec,alpha=0.99)))
        
        print(loss + ', spec , estim  , 0.9, ' + str(t2(estim_spec,clean_spec,alpha=0.9)))
        print(loss + ', spec , output , 0.9, ' + str(t2(output_spec,clean_spec,alpha=0.9))) 
        
        print(loss + ', spec , estim  , 0.5, ' + str(t2(estim_spec,clean_spec,alpha=0.5)))
        print(loss + ', spec , output , 0.5, ' + str(t2(output_spec,clean_spec,alpha=0.5))) 
        
        print(loss + ', spec , estim  , 0.1, ' + str(t2(estim_spec,clean_spec,alpha=0.1)))
        print(loss + ', spec , output , 0.1, ' + str(t2(output_spec,clean_spec,alpha=0.1))) 
    elif loss =='LSD' :
        print(loss + ', spec , estim  , ' + str(LSD(estim_spec,clean_spec)))
        print(loss + ', spec , output , ' + str(LSD(output_spec,clean_spec)))  

iSDR, wav  , estim  , tensor(2.0017)
iSDR, wav  , output , tensor(0.8596)
MSE , wav  , estim  , tensor(0.0924)
MSE , wav  , output , tensor(0.0232)
MSE , spec , estim  , tensor(0.4533)
MSE , spec , output , tensor(0.1473)
t1  , wav  , estim  , 0.9, tensor(0.0370)
t1  , wav  , output , 0.9, tensor(0.0093)
t1  , spec , estim  , 0.9, tensor(0.2989)
t1  , spec , output , 0.9, tensor(0.2089)
t1  , wav  , estim  , 0.5, tensor(8.5347e-05)
t1  , wav  , output , 0.5, tensor(1.9859e-05)
t1  , spec , estim  , 0.5, tensor(-8.3284e-09)
t1  , spec , output , 0.5, tensor(-2.8158e-08)
t1  , wav  , estim  , 0.1, tensor(-0.0369)
t1  , wav  , output , 0.1, tensor(-0.0093)
t1  , spec , estim  , 0.1, tensor(-0.2989)
t1  , spec , output , 0.1, tensor(-0.2089)
t2, spec , estim  , 1.0, tensor(0.0514)
t2, spec , output , 1.0, tensor(0.0946)
t2, spec , estim  , 0.99, tensor(0.0512)
t2, spec , output , 0.99, tensor(0.0936)
t2, spec , estim  , 0.9, tensor(0.0494)
t2, spec , output , 0.9, tensor(0.0851)
t2, spec ,