about sdr #3

Open
KiAlexander opened this issue May 27, 2020 · 2 comments
KiAlexander commented May 27, 2020

I am trying to test my code, which calculates SDR, on your separated samples (ex_18).

With my SDR code the result is about 6.47, while yours is 19.37.

Can you help me find out whether anything is wrong in my code? Thanks.

The code is as follows.

#!/usr/bin/env python

import soundfile as sf
from mir_eval.separation import bss_eval_sources
import numpy as np

import torch

from itertools import permutations

def cal_SDRi(src_ref, src_est, mix):
    # Calculate Source-to-Distortion Ratio improvement (SDRi).
    # NOTE: bss_eval_sources is very very slow.
    # Args:
    #     src_ref: numpy.ndarray, [C, T]
    #     src_est: numpy.ndarray, [C, T], reordered by best PIT permutation
    #     mix: numpy.ndarray, [T]
    # Returns:
    #     average_SDRi

    src_anchor = np.stack([mix, mix], axis=0)
    sdr, sir, sar, popt = bss_eval_sources(src_ref, src_est)
    sdr0, sir0, sar0, popt0 = bss_eval_sources(src_ref, src_anchor)
    avg_SDRi = ((sdr[0]-sdr0[0]) + (sdr[1]-sdr0[1])) / 2
    return avg_SDRi

def cal_SISNRi(src_ref, src_est, mix):
    # Calculate Scale-Invariant Source-to-Noise Ratio improvement (SI-SNRi)
    # Args:
    #     src_ref: numpy.ndarray, [C, T]
    #     src_est: numpy.ndarray, [C, T], reordered by best PIT permutation
    #     mix: numpy.ndarray, [T]
    # Returns:
    #     average_SISNRi
    # 
    sisnr1 = cal_SISNR(src_ref[0], src_est[0])
    sisnr2 = cal_SISNR(src_ref[1], src_est[1])
    sisnr1b = cal_SISNR(src_ref[0], mix)
    sisnr2b = cal_SISNR(src_ref[1], mix)
    avg_SISNRi = ((sisnr1 - sisnr1b) + (sisnr2 - sisnr2b)) / 2
    return avg_SISNRi

def cal_SISNR(ref_sig, out_sig, eps=1e-8):
    # Calculate Scale-Invariant Source-to-Noise Ratio (SI-SNR)
    # Args:
    #     ref_sig: numpy.ndarray, [T]
    #     out_sig: numpy.ndarray, [T]
    # Returns:
    #     SISNR

    assert len(ref_sig) == len(out_sig)
    ref_sig = ref_sig - np.mean(ref_sig)
    out_sig = out_sig - np.mean(out_sig)
    ref_energy = np.sum(ref_sig ** 2) + eps
    proj = np.sum(ref_sig * out_sig) * ref_sig / ref_energy
    noise = out_sig - proj
    ratio = np.sum(proj ** 2) / (np.sum(noise ** 2) + eps)
    sisnr = 10 * np.log(ratio + eps) / np.log(10.0)
    return sisnr


def calc_sdr(estimation, origin):

    # batch-wise SDR calculation for one audio file.
    # estimation: (batch, nsample)
    # origin: (batch, nsample)

      
    origin_power = np.sum(origin**2, 1, keepdims=True) + 1e-8  # (batch, 1)
    
    scale = np.sum(origin*estimation, 1, keepdims=True) / origin_power  # (batch, 1)
    
    est_true = scale * origin  # (batch, nsample)
    est_res = estimation - est_true  # (batch, nsample)
    
    true_power = np.sum(est_true**2, 1)
    res_power = np.sum(est_res**2, 1)
    
    return 10*np.log10(true_power) - 10*np.log10(res_power)  # (batch,)

def compute_measures(se,s,j):
    Rss=s.transpose().dot(s)
    this_s=s[:,j]

    a=this_s.transpose().dot(se)/Rss[j,j]
    e_true=a*this_s
    e_res=se-a*this_s
    Sss=np.sum((e_true)**2)
    Snn=np.sum((e_res)**2)

    SDR=10*np.log10(Sss/Snn)

    Rsr= s.transpose().dot(e_res)
    b=np.linalg.inv(Rss).dot(Rsr)

    e_interf = s.dot(b)
    e_artif= e_res-e_interf

    SIR=10*np.log10(Sss/np.sum((e_interf)**2))
    SAR=10*np.log10(Sss/np.sum((e_artif)**2))
    return SDR, SIR, SAR

def GetSDR(se,s):
    se = se.transpose()
    s = s.transpose()

    se=se-np.mean(se,axis=0)
    s=s-np.mean(s,axis=0)
    nsampl,nsrc=se.shape
    nsampl2,nsrc2=s.shape

    assert(nsrc2==nsrc)
    assert(nsampl2==nsampl)

    SDR=np.zeros((nsrc,nsrc))
    SIR=SDR.copy()
    SAR=SDR.copy()

    for jest in range(nsrc):
        for jtrue in range(nsrc):
            SDR[jest,jtrue],SIR[jest,jtrue],SAR[jest,jtrue]=compute_measures(se[:,jest],s,jtrue)


    perm=list(permutations(np.arange(nsrc)))
    nperm=len(perm)
    meanSIR=np.zeros((nperm,))
    for p in range(nperm):
        tp=SIR.transpose().reshape(nsrc*nsrc)
        idx=np.arange(nsrc)*nsrc+list(perm[p])
        meanSIR[p]=np.mean(tp[idx])
    popt=np.argmax(meanSIR)
    per=list(perm[popt])
    idx=np.arange(nsrc)*nsrc+per
    SDR=SDR.transpose().reshape(nsrc*nsrc)[idx]
    SIR=SIR.transpose().reshape(nsrc*nsrc)[idx]
    SAR=SAR.transpose().reshape(nsrc*nsrc)[idx]
    return SDR, SIR, SAR, per

EPS = 1e-8


def cal_si_snr_with_pit(source, estimate_source, source_lengths):
    # Calculate SI-SNR with PIT training.
    # Args:
    #     source: [B, C, T], B is batch size
    #     estimate_source: [B, C, T]
    #     source_lengths: [B], each item is between [0, T]

    assert source.size() == estimate_source.size()
    B, C, T = source.size()

    # Step 1. Zero-mean norm
    num_samples = source_lengths.view(-1, 1, 1).float()  # [B, 1, 1]
    mean_target = torch.sum(source, dim=2, keepdim=True) / num_samples
    mean_estimate = torch.sum(estimate_source, dim=2, keepdim=True) / num_samples
    zero_mean_target = source - mean_target
    zero_mean_estimate = estimate_source - mean_estimate

    # Step 2. SI-SNR with PIT
    # reshape to use broadcast
    s_target = torch.unsqueeze(zero_mean_target, dim=1)  # [B, 1, C, T]
    s_estimate = torch.unsqueeze(zero_mean_estimate, dim=2)  # [B, C, 1, T]
    # s_target = <s', s>s / ||s||^2
    pair_wise_dot = torch.sum(s_estimate * s_target, dim=3, keepdim=True)  # [B, C, C, 1]
    s_target_energy = torch.sum(s_target ** 2, dim=3, keepdim=True) + EPS  # [B, 1, C, 1]
    pair_wise_proj = pair_wise_dot * s_target / s_target_energy  # [B, C, C, T]
    # e_noise = s' - s_target
    e_noise = s_estimate - pair_wise_proj  # [B, C, C, T]
    # SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2)
    pair_wise_si_snr = torch.sum(pair_wise_proj ** 2, dim=3) / (torch.sum(e_noise ** 2, dim=3) + EPS)
    pair_wise_si_snr = 10 * torch.log10(pair_wise_si_snr + EPS)  # [B, C, C]
    print('sisnr:',pair_wise_si_snr)

    # Get max_snr of each utterance
    # permutations, [C!, C]
    perms = source.new_tensor(list(permutations(range(C))), dtype=torch.long)
    # one-hot, [C!, C, C]
    index = torch.unsqueeze(perms, 2)
    perms_one_hot = source.new_zeros((*perms.size(), C)).scatter_(2, index, 1)
    # [B, C!] <- [B, C, C] einsum [C!, C, C], SI-SNR sum of each permutation
    snr_set = torch.einsum('bij,pij->bp', [pair_wise_si_snr, perms_one_hot])
    max_snr_idx = torch.argmax(snr_set, dim=1)  # [B]
    # max_snr = torch.gather(snr_set, 1, max_snr_idx.view(-1, 1))  # [B, 1]
    max_snr, _ = torch.max(snr_set, dim=1, keepdim=True)
    max_snr /= C
    return max_snr, perms, max_snr_idx

def _sdr( y, z, SI=False):
    if SI:
        a = ((z*y).mean(-1) / (y*y).mean(-1)).unsqueeze(-1) * y
        return 10*torch.log10( (a**2).mean(-1) / ((a-z)**2).mean(-1))
    else:
        return 10*torch.log10( (y*y).mean(-1) / ((y-z)**2).mean(-1))

def test():   
    mix = sf.read('./ex_18/mixture.wav')[0]
    source = np.stack([sf.read('./ex_18/s1.wav')[0], sf.read('./ex_18/s2.wav')[0]], axis=0)
    estimate_source = np.stack([sf.read('./ex_18/s1_estimate.wav')[0], sf.read('./ex_18/s2_estimate.wav')[0]], axis=0)

    SDRi =cal_SDRi(source,estimate_source,mix)
    SISNRi = cal_SISNRi(source,estimate_source,mix)
    print('SDRi:{}'.format(SDRi))
    print('SISNRi:{}\n'.format(SISNRi))

    # calc_sdr expects (estimation, origin)
    sdr1 = calc_sdr(estimate_source, source)
    sdr2 = calc_sdr(np.stack([mix, mix], axis=0), source)
    sdri = np.mean(sdr1-sdr2)
    print('sdr1:{}'.format(sdr1))
    print('sdr2:{}'.format(sdr2))
    print('sdri:{}\n'.format(sdri))

    SDR, SIR, SAR, per = GetSDR(estimate_source, source)
    print('SDR:{}\nSIR:{}\nSAR:{}\nper:{}\n'.format(SDR, SIR, SAR, per))

    source_lengths = torch.from_numpy(np.array([mix.shape]))
    max_snr, _, _ = cal_si_snr_with_pit(torch.from_numpy(np.array([source])).float(),torch.from_numpy(np.array([estimate_source])).float(),source_lengths)
    print('max_snr:{}\n'.format(max_snr))

    SISDR = _sdr(torch.from_numpy(np.array([source])).float(),torch.from_numpy(np.array([estimate_source])).float(),SI=True)
    print('SISDR: ',SISDR)


if __name__ == '__main__':
    test()

And the output:

    #     SDRi:7.892910056532607
    #     SISNRi:7.151290758024819

    #     sdr1:[6.68677879 6.25589817]
    #     sdr2:[-1.18234941 -0.17591568]
    #     sdri:7.150471030776565

    #     SDR:[6.68712455 6.25720926]
    #     SIR:[34.92616321 34.92229867]
    #     SAR:[6.69364393 6.26311903]
    #     per:[0, 1]

    #     sisnr: tensor([[[  6.6871, -34.0962],
    #              [-34.1721,   6.2572]]])
    #     max_snr:  tensor([[6.4722]])

    #     SISDR:  tensor([[6.6868, 6.2559]])

    # ex_18/metrics.json
    # {
    # "input_si_sdr": 0.028149127960205078,
    # "input_sdr": 0.15109104033014964,
    # "input_sir": 0.1510910403301708,
    # "input_sar": 144.89122580687916,
    # "input_stoi": 0.7178163832006375,
    # "input_pesq": 1.599277138710022,
    # "si_sdr": 19.083293914794922,
    # "sdr": 19.376235432506704,
    # "sir": 30.187015165321924,
    # "sar": 19.759935974444744,
    # "stoi": 0.9568062920227058,
    # "pesq": 3.562618613243103,
    # "mix_path": "/mnt/data/wham/wav8k/min/tt/mix_clean/050a050c_0.050237_442c020j_-0.050237.wav"
    # }
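
As a small arithmetic check on the numbers printed above: the "about 6.47" figure appears to be the `max_snr` value, which `cal_si_snr_with_pit` computes as the per-source SI-SNR summed over the best permutation and divided by `C`. The sketch below reproduces it from the diagonal of the printed `sisnr` tensor. If that reading is right, the closest counterpart in `metrics.json` would be `si_sdr` (19.08) rather than `sdr` (19.37). The variable names are illustrative only.

```python
# Diagonal of the printed 'sisnr' tensor (permutation [0, 1]).
per_source_si_snr = [6.6871, 6.2572]

# cal_si_snr_with_pit sums over the best permutation and divides by C.
max_snr = sum(per_source_si_snr) / len(per_source_si_snr)

print(f"{max_snr:.3f}")  # 6.472
```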
@KiAlexander

Author

And here is the result calculated by pb_bss_eval:

{'input_pesq': 1.5960057377815247,
 'input_sar': 11.243911897495014,
 'input_sdr': -0.5471245564855747,
 'input_si_sdr': -0.7163815595640699,
 'input_sir': 0.14099714373415928,
 'input_stoi': 0.662681954751808,
 'pesq': 2.596807837486267,
 'sar': 6.761317115018516,
 'sdr': 6.659954133819946,
 'si_sdr': 5.684012353271584,
 'sir': 23.99486035934128,
 'stoi': 0.8663816779000003}
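
The two reports disagree sharply on `input_sar` (144.89 in `metrics.json`, 11.24 from pb_bss_eval). A very high input SAR is what one would expect when the mixture is almost exactly the sum of the reference sources, so one possible check is to measure how much of the mixture is not explained by `s1 + s2`. The sketch below is only that: the helper name `mixture_residual_db` and the synthetic signals are assumptions for illustration, not part of either code base.

```python
import numpy as np

def mixture_residual_db(mix, sources, eps=1e-8):
    """Energy of (mix - sum of sources) relative to the mixture energy, in dB.

    mix: [T], sources: [C, T]. Strongly negative values mean the mixture is
    (almost) exactly the sum of the sources.
    """
    mix = np.asarray(mix, dtype=np.float64)
    sources = np.asarray(sources, dtype=np.float64)
    residual = mix - sources.sum(axis=0)
    return 10 * np.log10((np.sum(residual ** 2) + eps) / (np.sum(mix ** 2) + eps))

# Synthetic stand-ins for the wav files (assumed data, not ex_18).
rng = np.random.default_rng(0)
s = rng.standard_normal((2, 8000))
clean_mix = s.sum(axis=0)                                # mixture == s1 + s2
noisy_mix = clean_mix + 0.3 * rng.standard_normal(8000)  # extra noise in the mixture

clean_db = mixture_residual_db(clean_mix, s)
noisy_db = mixture_residual_db(noisy_mix, s)
print(clean_db, noisy_db)
```

With the files from this thread, one would pass `sf.read('./ex_18/mixture.wav')[0]` as `mix` and the stacked `s1.wav` and `s2.wav` arrays as `sources`. A residual well above the numerical floor would suggest that the uploaded files are not the ones the original metrics were computed on.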

@etzinis
Owner

etzinis commented May 29, 2020

Hey, thanks for reaching out. I am trying to meet a deadline, so I had not checked my GitHub issues. Your code seems fine. Surprisingly, the bug is probably in the uploaded sources: if you listen to both the estimates and the actual sources, you can hear that the sources and the mixture sound very noisy, whereas the estimates seem to be of noticeably better quality, with far fewer artifacts. Moreover, I have used this code to produce some audio examples at https://github.com/mpariente/asteroid/tree/master/egs/wham/TwoStep, which might also contain this noise because of the WHAM dataset. I am going to take a look at this, hopefully in a few days.
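
To illustrate why noisy reference files alone could pull the score down this far, here is a small synthetic sketch: an estimate that is close to the clean signal scores about 20 dB against the clean reference, but only a few dB against a reference that carries extra noise. The signals, noise levels, and the `si_sdr` helper are assumptions for illustration. This is not the WHAM data or the repository's evaluation code.

```python
import numpy as np

def si_sdr(reference, estimate, eps=1e-8):
    """Scale-invariant SDR in dB of `estimate` with respect to `reference`."""
    reference = reference - reference.mean()
    estimate = estimate - estimate.mean()
    # Project the estimate onto the reference to get the target component.
    scale = np.dot(estimate, reference) / (np.dot(reference, reference) + eps)
    target = scale * reference
    noise = estimate - target
    return 10 * np.log10((np.sum(target ** 2) + eps) / (np.sum(noise ** 2) + eps))

rng = np.random.default_rng(0)
clean = rng.standard_normal(16000)
estimate = clean + 0.1 * rng.standard_normal(16000)   # estimate close to the clean signal
noisy_ref = clean + 0.5 * rng.standard_normal(16000)  # reference file with added noise

sdr_vs_clean = si_sdr(clean, estimate)      # roughly 20 dB
sdr_vs_noisy = si_sdr(noisy_ref, estimate)  # only a few dB
print(sdr_vs_clean, sdr_vs_noisy)
```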
