In [4]:
import os
from utils import snr_db, resample_signal
import numpy as np
from scipy.io import wavfile
from methods import wiener_filter, spectral_subtraction
from pesq import pesq
from pystoi import stoi
import matplotlib.pyplot as plt

In [None]:
# Evaluate the classical enhancement baselines (Wiener filter, spectral subtraction)
# against the unprocessed noisy input. Running sums of PESQ, SNR and STOI are kept
# here; the next three cells divide them by `counter` to report means.
TARGET_SR = 16000        # every signal is resampled to this rate before scoring
NOISE_EST_SECONDS = 0.5  # length of the leading noisy segment used as the noise estimate

directory = os.fsencode("noisy_testset_wav")

counter = 0
pesq_wiener = 0
pesq_noisy = 0
pesq_sub = 0
snr_wiener = 0
snr_noisy = 0
snr_sub = 0
stoi_wiener = 0
stoi_noisy = 0
stoi_sub = 0

for file in os.listdir(directory):
    filename = os.fsdecode(file)
    try:
        sr_noisy, noisy = wavfile.read(f"noisy_testset_wav/{filename}")
        # The clean reference is looked up under the same file name.
        sr_clean, clean = wavfile.read(f"clean_testset_wav/{filename}")
        noisy = resample_signal(noisy.astype(np.float32), sr_noisy, TARGET_SR)
        clean = resample_signal(clean.astype(np.float32), sr_clean, TARGET_SR)

        # NOTE(review): assumes the first NOISE_EST_SECONDS of each recording contain
        # no speech, so they can serve as the noise estimate — confirm for this dataset.
        noise_estimate = noisy[:int(NOISE_EST_SECONDS * TARGET_SR)]
        wiener_signal = wiener_filter(noisy, noise_estimate)
        spectral_sub = spectral_subtraction(noisy, noise_estimate)

        # Compute every score before touching any running sum, so a failure part-way
        # through (e.g. in STOI) cannot leave some sums including this file while
        # `counter` excludes it.
        file_pesq = (
            pesq(TARGET_SR, clean, noisy),
            pesq(TARGET_SR, clean, wiener_signal),
            pesq(TARGET_SR, clean, spectral_sub),
        )
        file_snr = (
            snr_db(clean, noisy),
            snr_db(clean, wiener_signal),
            snr_db(clean, spectral_sub),
        )
        file_stoi = (
            stoi(clean, noisy, TARGET_SR),
            stoi(clean, wiener_signal, TARGET_SR),
            stoi(clean, spectral_sub, TARGET_SR),
        )
    except Exception as exc:
        # Skip files that cannot be read or scored, but report why; a bare `except:`
        # would also swallow KeyboardInterrupt and hide the cause.
        print(f"{filename}: skipped ({exc!r})")
    else:
        pesq_noisy += file_pesq[0]
        pesq_wiener += file_pesq[1]
        pesq_sub += file_pesq[2]
        snr_noisy += file_snr[0]
        snr_wiener += file_snr[1]
        snr_sub += file_snr[2]
        stoi_noisy += file_stoi[0]
        stoi_wiener += file_stoi[1]
        stoi_sub += file_stoi[2]
        counter += 1

In [None]:
# Mean PESQ per method: (noisy input, Wiener filter, spectral subtraction)
tuple(total / counter for total in (pesq_noisy, pesq_wiener, pesq_sub))

In [None]:
# Mean SNR in dB per method: (noisy input, Wiener filter, spectral subtraction)
tuple(total / counter for total in (snr_noisy, snr_wiener, snr_sub))

In [None]:
# Mean STOI per method: (noisy input, Wiener filter, spectral subtraction)
tuple(total / counter for total in (stoi_noisy, stoi_wiener, stoi_sub))

In [None]:
# Score the SEGAN outputs against the clean references (PESQ, SNR, STOI).
# NOTE(review): assumes every SEGAN file name is the clean file name with a
# 4-character prefix (hence `filename[4:]`) — confirm against how segan_results was written.
TARGET_SR = 16000  # every signal is resampled to this rate before scoring

directory = os.fsencode("segan_results")
counter = 0
pesq_segan = 0
snr_segan = 0
stoi_segan = 0

for file in os.listdir(directory):
    filename = os.fsdecode(file)
    try:
        sr_segan, segan = wavfile.read(f"segan_results/{filename}")
        sr_clean, clean = wavfile.read(f"clean_testset_wav/{filename[4:]}")
        segan = resample_signal(segan.astype(np.float32), sr_segan, TARGET_SR)
        clean = resample_signal(clean.astype(np.float32), sr_clean, TARGET_SR)
        # Compute all three scores before updating any sum, so a failure in a later
        # metric cannot leave this file counted in one sum but not in `counter`.
        file_pesq = pesq(TARGET_SR, clean, segan)
        file_snr = snr_db(clean, segan)
        file_stoi = stoi(clean, segan, TARGET_SR)
    except Exception as exc:
        # Skip unreadable/unscorable files but report the reason instead of hiding it.
        print(f"{filename}: skipped ({exc!r})")
    else:
        pesq_segan += file_pesq
        snr_segan += file_snr
        stoi_segan += file_stoi
        counter += 1

In [None]:
# Mean SEGAN scores: (PESQ, SNR in dB, STOI)
tuple(total / counter for total in (pesq_segan, snr_segan, stoi_segan))

In [8]:
# Score the adversarially trained SEGAN. Its predictions for the whole test set are
# stored back-to-back in a single WAV file; the clean files are walked in sorted order
# and each prediction is cut out with the same length as its clean reference.
# NOTE(review): relies on the predictions having been concatenated in exactly this
# sorted-filename order — confirm against whatever wrote raw_pred_adv_segan.wav.
TARGET_SR = 16000
# Both the prediction and the reference are assumed to be 16-bit PCM and are scaled by
# the same full-scale value. Scaling only one of them would leave the reference ~32768x
# louder than the estimate and make the SNR figure meaningless.
INT16_FULL_SCALE = 32768.0
ADV_RESULTS_DIR = "../segan_adv_results"

sr_adv_segan_full, adv_segan_full = wavfile.read("../raw_pred_adv_segan.wav")
adv_segan_full = adv_segan_full.astype(np.float32) / INT16_FULL_SCALE
adv_segan_full = resample_signal(adv_segan_full, sr_adv_segan_full, TARGET_SR)

# wavfile.write below fails if the output directory is missing.
os.makedirs(ADV_RESULTS_DIR, exist_ok=True)
directory = os.fsencode("../clean_testset_wav")

start = 0
counter = 0
pesq_adv_segan = 0
snr_adv_segan = 0
stoi_adv_segan = 0

for file in sorted(os.listdir(directory)):
    filename = os.fsdecode(file)

    sr_clean, clean = wavfile.read(f"../clean_testset_wav/{filename}")
    clean = clean.astype(np.float32) / INT16_FULL_SCALE
    clean = resample_signal(clean, sr_clean, TARGET_SR)
    adv_segan = adv_segan_full[start:start + len(clean)]
    # Slicing past the end silently returns fewer samples; fail with a clear message
    # instead of an opaque length-mismatch error from the metrics.
    if len(adv_segan) != len(clean):
        raise ValueError(
            f"{filename}: only {len(adv_segan)} of {len(clean)} predicted samples "
            f"available at offset {start}; the concatenated prediction is too short"
        )
    start += len(clean)

    pesq_sample = pesq(TARGET_SR, clean, adv_segan)
    pesq_adv_segan += pesq_sample
    snr_adv_segan += snr_db(clean, adv_segan)
    stoi_adv_segan += stoi(clean, adv_segan, TARGET_SR)

    # The PESQ score is prefixed to the name, presumably so outputs can be browsed by score.
    wavfile.write(f"{ADV_RESULTS_DIR}/{pesq_sample}__{filename}", TARGET_SR, adv_segan)

    counter += 1

In [None]:
# Score the WaveNet-enhanced outputs against the clean references (same file names).
TARGET_SR = 16000  # every signal is resampled to this rate before scoring

directory = os.fsencode("wavenet_results/enhanced")

counter = 0
pesq_wavenet = 0
snr_wavenet = 0
stoi_wavenet = 0

for file in sorted(os.listdir(directory)):
    filename = os.fsdecode(file)
    try:
        sr_wavenet, wavenet = wavfile.read(f"wavenet_results/enhanced/{filename}")
        sr_clean, clean = wavfile.read(f"clean_testset_wav/{filename}")
        wavenet = resample_signal(wavenet.astype(np.float32), sr_wavenet, TARGET_SR)
        clean = resample_signal(clean.astype(np.float32), sr_clean, TARGET_SR)
        # PESQ is computed once and reused for both the running sum and the printout.
        pesq_sample = pesq(TARGET_SR, clean, wavenet)
        # TODO: SNR and STOI are disabled, so snr_wavenet / stoi_wavenet stay 0.
        # To report them, re-enable these and add them to the sums in the `else` branch:
        # snr_sample = snr_db(clean, wavenet)
        # stoi_sample = stoi(clean, wavenet, TARGET_SR)
    except Exception as exc:
        # Skip unreadable/unscorable files but report the reason instead of hiding it.
        print(f"{filename}: skipped ({exc!r})")
    else:
        print(filename, pesq_sample)
        pesq_wavenet += pesq_sample
        counter += 1


In [None]:
# Number of files scored without an exception — when run right after the WaveNet
# evaluation cell, which resets and increments `counter`.
print(counter)

In [None]:
# Mean WaveNet scores: (PESQ, SNR, STOI). The SNR/STOI accumulation in the WaveNet
# cell is commented out, so snr_wavenet and stoi_wavenet stay at their initial 0 and
# the last two values are 0 — they are placeholders, not real scores.
pesq_wavenet / counter, snr_wavenet / counter, stoi_wavenet / counter

In [None]:
# Mean WaveNet PESQ on its own (same value as the first element of the previous cell).
pesq_wavenet / counter