In [None]:
!git clone https://github.com/pytorch/audio.git

In [1]:
import sys
sys.path.append('./audio/examples/')

In [2]:
from beamforming.mvdr import MVDR
import torch
import torchaudio
import IPython.display as ipd

################################################################################
###          (please add 'export KALDI_ROOT=<your_path>' in your $HOME/.profile)
###          (or run as: KALDI_ROOT=<your_path> python <your_script>.py)
################################################################################



In [3]:
mix, sr = torchaudio.load('./wavs/mix.wav')
reverb_clean, sr2 = torchaudio.load('./wavs/reverb_clean.wav')
clean, sr3 = torchaudio.load('./wavs/clean.wav')
assert sr == sr2
noise = mix - reverb_clean
mix = mix.to(torch.double)
noise = noise.to(torch.double)
clean = clean.to(torch.double)

In [4]:
stft = torchaudio.transforms.Spectrogram(n_fft=1024, hop_length=256, return_complex=True, power=None)
istft = torchaudio.transforms.InverseSpectrogram(n_fft=1024, hop_length=256)

### Compute the complex-valued STFT of mixture, clean speech, and noise

In [5]:
spec_mix = stft(mix)
spec_clean = stft(clean)
spec_noise = stft(noise)

### Generate the Ideal Ratio Mask (IRM). Floor the mask value to be [0, 1]

In [6]:
def get_irms(spec_clean, spec_noise, spec_mix):
    mag_mix = spec_mix.abs() ** 2
    mag_clean = spec_clean.abs() ** 2
    mag_noise = spec_noise.abs() ** 2
    irm_speech = mag_clean / (mag_mix + 1e-8)
    irm_noise = mag_noise / (mag_mix + 1e-8)
    irm_speech[irm_speech>=1] =1.
    irm_noise[irm_noise>=1] = 1.
    return irm_speech, irm_noise

In [7]:
irm_speech, irm_noise = get_irms(spec_clean, spec_noise, spec_clean)

### Apply MVDR beamforming by using multi-channel masks

In [8]:
results_multi = {}
for solution in ['ref_channel', 'stv_evd', 'stv_power']:
    mvdr = MVDR(ref_channel=0, solution=solution, multi_mask=True)
    stft_est = mvdr(spec_mix, irm_speech, irm_noise)
    est = istft(stft_est, length=mix.shape[-1])
    results_multi[solution] = est

### Apply MVDR beamforming by using single-channel masks

In [9]:
results_single = {}
for solution in ['ref_channel', 'stv_evd', 'stv_power']:
    mvdr = MVDR(ref_channel=0, solution=solution, multi_mask=False)
    stft_est = mvdr(spec_mix, irm_speech[0], irm_noise[0])
    est = istft(stft_est, length=mix.shape[-1])
    results_single[solution] = est

### Compute Si-SDR scores

In [10]:
from source_separation.utils.metrics import sdr
def si_sdr(estimate, reference, epsilon=1e-8):
    estimate = estimate - estimate.mean()
    reference = reference - reference.mean()
    reference_pow = reference.pow(2).mean(axis=1, keepdim=True)
    mix_pow = (estimate * reference).mean(axis=1, keepdim=True)
    scale = mix_pow / (reference_pow + epsilon)

    reference = scale * reference
    error = estimate - reference

    reference_pow = reference.pow(2)
    error_pow = error.pow(2)

    reference_pow = reference_pow.mean(axis=1)
    error_pow = error_pow.mean(axis=1)

    sisdr = 10 * torch.log10(reference_pow) - 10 * torch.log10(error_pow)
    return sisdr.item()

### Single-channel mask results

In [11]:
for solution in results_single:
    print(solution+": ", si_sdr(results_single[solution][None,...], clean[0:1]))

ref_channel:  8.771621920879326
stv_evd:  7.344749326823148
stv_power:  8.453569938642666


### Multi-channel mask results

In [12]:
for solution in results_multi:
    print(solution+": ", si_sdr(results_multi[solution][None,...], clean[0:1]))

ref_channel:  8.346077085786277
stv_evd:  6.9351273105855
stv_power:  7.828446229086943


### Display the mixture audio

In [13]:
print("Mixture speech")
ipd.Audio(mix[0], rate=16000)

Mixture speech


### Display the noise

In [14]:
print("Noise")
ipd.Audio(noise[0], rate=16000)

Noise


### Display the clean speech

In [15]:
print("Clean speech")
ipd.Audio(clean[0], rate=16000)

Clean speech


### Display the enhanced audios

In [16]:
print("multi-channel mask, ref_channel solution")
ipd.Audio(results_multi['ref_channel'], rate=16000)

multi-channel mask, ref_channel solution


In [17]:
print("multi-channel mask, stv_evd solution")
ipd.Audio(results_multi['stv_evd'], rate=16000)

multi-channel mask, stv_evd solution


In [18]:
print("multi-channel mask, stv_power solution")
ipd.Audio(results_multi['stv_power'], rate=16000)

multi-channel mask, stv_power solution


In [19]:
print("single-channel mask, ref_channel solution")
ipd.Audio(results_single['ref_channel'], rate=16000)

single-channel mask, ref_channel solution


In [20]:
print("single-channel mask, stv_evd solution")
ipd.Audio(results_single['stv_evd'], rate=16000)

single-channel mask, stv_evd solution


In [21]:
print("single-channel mask, stv_power solution")
ipd.Audio(results_single['stv_power'], rate=16000)

single-channel mask, stv_power solution
