In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torchaudio
from df import enhance, init_df
from os.path import join as pjoin
import matplotlib.pyplot as plt
import os
import tqdm
import time
from IPython.display import Audio #listen: ipd.Audio(real.detach().cpu().numpy(), rate=FS)
import numpy as np
import scipy.signal as sig
import pandas as pd

  from torchaudio.backend.common import AudioMetaData


In [2]:
np.random.seed(0)
FS=48000
DURATION = 2

In [3]:
def extend_signal(signal, target_length):
    """
    Extend a signal by repeating it if it's shorter than the target length.
    
    Args:
    signal (torch.Tensor): Input signal.
    target_length (int): Desired length of the extended signal.

    Returns:
    torch.Tensor: Extended signal.
    """
    current_length = signal.size(0)
    if current_length < target_length:
        repetitions = target_length // current_length
        remainder = target_length % current_length
        extended_signal = signal.repeat(repetitions)
        if remainder > 0:
            extended_signal = torch.cat((extended_signal, signal[:remainder]), dim=0)
        return extended_signal
    else:
        return signal

def load_audio(apath):
    audio, fs = torchaudio.load(apath)
    if fs != FS:
        #print('resampling')
        resampler = torchaudio.transforms.Resample(fs, FS)
        audio = resampler(audio)    
    if len(audio.shape) > 1:
            audio = audio[0,:]
    return audio

def power(signal):
    return np.mean(signal**2)

In [4]:
speech_path = '/home/ubuntu/Data/DFN/textfiles/readspeech_set.txt'
noise_path = '/home/ubuntu/Data/DFN/textfiles/test_set_noise.txt'
rir_path = '/home/ubuntu/Data/DFN/textfiles/real_rirs.txt'
#speakerphone_path = '/home/ubuntu/Data/DFN/textfiles/DNS5_val_speakerphone.txt'
#headset_path = '/home/ubuntu/Data/DFN/textfiles/DNS5_val_headset.txt'


In [5]:
# load speech wav paths from the textfile
speech_paths = []
with open(speech_path, 'r') as file:
    lines = file.readlines()
    for line in lines:
        speech_paths.append(line.rstrip()) 
print('speech set loaded. contains '+str(len(speech_paths)) +' files.')

speech set loaded. contains 41194 files.


In [6]:
# load speech wav paths from the textfile
noise_paths = []
with open(noise_path, 'r') as file:
    lines = file.readlines()
    for line in lines:
        noise_paths.append(line.rstrip()) 

In [7]:
noise_paths_val = []
noise_paths_train = []
with open('/home/ubuntu/Data/DFN/textfiles/training_set_noise.txt', 'r') as file:
    lines = file.readlines()
    for line in lines:
        noise_paths_train.append(line.rstrip()) 
with open('/home/ubuntu/Data/DFN/textfiles/validation_set_noise.txt', 'r') as file:
    lines = file.readlines()
    for line in lines:
        noise_paths_val.append(line.rstrip()) 

In [8]:
# load speech wav paths from the textfile
rir_paths = []
with open(rir_path, 'r') as file:
    lines = file.readlines()
    for line in lines:
        rir_paths.append(line.rstrip()) 

In [9]:
ns = []
for n in tqdm.tqdm(noise_paths):
    if n not in noise_paths_val:
        if n not in noise_paths_train:
            ns.append(n)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9565/9565 [00:09<00:00, 1021.86it/s]


In [10]:
noise_paths = ns

In [11]:
TRAINRIR_NAMES = {'D01_sb_none_NH_mono': 'singleband' , 'D02_mb_none_NH_mono': 'multiband', 
            'D03_mb_rec_NH_left': 'recdirectivity', 'D05_mb_srcrec_NH_left': 'recsourcedirectivity',
            'D00_DNS5': 'DNS5', 'D09_SSmp3d_left' : 'soundspaces'}

use_gpu = True
if torch.cuda.is_available() and use_gpu:
    TORCH_DEVICE = "cuda"
else:
    TORCH_DEVICE = "cpu"

model_names = list(TRAINRIR_NAMES.keys())


In [12]:
# we randomly pick 40 speech, noise and real rirs

In [13]:
np.random.seed(0)

In [14]:
speech_paths = np.random.choice(speech_paths, 101)
noise_paths = np.random.choice(noise_paths, 101)
rir_paths = np.random.choice(rir_paths, 101)

In [15]:
#we remove a corrupt file

In [16]:
rir_paths = np.delete(rir_paths, rir_paths=='/home/ubuntu/Data/MIT_IR_Survey/h195_Outside_SuburbanFronyYard_1txts.wav')

In [17]:
speech_paths = np.delete(speech_paths, speech_paths=='/home/ubuntu/Data/DNS-Challenge/datasets_fullband/clean_fullband/read_speech/book_02476_chp_0010_reader_09190_12_seg_1.wav')
noise_paths = np.delete(noise_paths, noise_paths=='/home/ubuntu/Data/DNS-Challenge/datasets_fullband/noise_fullband/door_Freesound_validated_470511_7.wav')

In [18]:
snrs = np.linspace(0, 30, 8)

In [19]:
for model_name in model_names:
    model_path = pjoin('/home/ubuntu/Data/DFN', model_name)
    model, df_state, _ = init_df(model_path)
    
    s = snrs.repeat(int(np.ceil(len(speech_paths)/len(snrs))))
    path = pjoin('listening_test_drynoise')
    for i, speech_pth in enumerate(speech_paths):
        snr = s[i]
        clean = load_audio(speech_paths[i])
        if i < 10:
            idx = '0'+str(i)    
        else:
            idx = str(i)
        noise = load_audio(noise_paths[i])
        try:
            rir = load_audio(rir_paths[i])
        except:
            np.random.seed(0)
            rir = load_audio(np.random.choice(rir_paths, 1)[0])
    
        # we extend speech and noise if too short
        if len(clean) < FS * DURATION:
            clean = extend_signal(clean, FS*DURATION)
        if len(noise) < FS * DURATION:
            noise = extend_signal(noise, FS*DURATION)
                       
        # back to numpy for easy conv
        clean = clean.numpy()
        noise = noise.numpy()
        rir = rir.numpy()
            
        # we choose the signal chunk with more energy (to avoid silent chunks)
        nchunks = len(clean) // (FS*DURATION)
        chunks = np.split(clean[: FS * DURATION * nchunks], nchunks)
        powers = np.array([power(x) for x in chunks])
        clean = clean[np.argmax(powers) * FS * DURATION : (np.argmax(powers) + 1 ) *  FS * DURATION]
        
        nchunks = len(noise) // (FS*DURATION)
        chunks = np.split(noise[: FS * DURATION * nchunks], nchunks)
        powers = np.array([power(x) for x in chunks])
        noise = noise[np.argmax(powers) * FS * DURATION : (np.argmax(powers) + 1 ) *  FS * DURATION]
    
        #handle silent noise
        noise_nrgy = power(noise)
        if noise_nrgy == 0.:
            #print('silent noise sample, using white noise')
            noise = np.random.randn( FS * DURATION )
    
        # we set the SNR
        ini_snr = 10 * np.log10(power(clean) / power(noise))
        noise_gain_db = ini_snr - snr
        noise *= np.power(10, noise_gain_db/20)
    
        # we normalize to 0.9 if mixture is close to clipping
        clips = np.max(np.abs(clean + noise))
        if clips >= 0.9:
            clips /= 0.9
            noise /= clips
            clean /= clips
        # or to -18dBfs if smaller than that:
        elif clips <= 10**(-18/20):
            clips /= 10**(-18/20)
            noise /= clips 
            clean /= clips    
    
        # apply rir 
        revspeech = sig.fftconvolve(clean, rir, 'full')
        # synchronize reverberant with anechoic
        lag = np.where(np.abs(rir) >= 0.5*np.max(np.abs(rir)))[0][0] # we take as direct sound the first value (from the left) that's at most -6dB from max
    
        revspeech = revspeech[lag:FS*DURATION + lag]
    
        # enforce energy conservation
        revspeech *= np.sqrt(power(clean) / power(revspeech)) 
    
        #apply RIR to noise too if needed
        #if self.reverberant_noises:
        #rnoise = sig.fftconvolve(noise, rir, 'full')
        #rnoise = rnoise[lag:FS*DURATION + lag]
        #rnoise *= np.sqrt(power(noise) / power(rnoise))
        #noise = rnoise
        noisy = revspeech + noise
        noisy = torch.from_numpy(noisy)
        enhanced = enhance(model, df_state, noisy.unsqueeze(0))
        enhanced *= np.sqrt(power(clean) / power(enhanced.numpy())) 
        if torch.max(enhanced) > 1.0:
            print(i)
            print('clipping...')
            clip_factor = torch.max(enhanced)
            enhanced/=clip_factor
            clean/=clip_factor.item()
        torchaudio.save(pjoin(path, idx+'_snr_'+str(int(np.round(snr)))+'_noisy.flac'), noisy.unsqueeze(0), FS)
        torchaudio.save(pjoin(path, idx+'_snr_'+str(int(np.round(snr)))+'_clean.flac'), torch.from_numpy(clean).unsqueeze(0), FS)
        torchaudio.save(pjoin(path, idx+'_snr_'+str(int(np.round(snr)))+'_'+model_name+'.flac'), enhanced, FS)
print('Done.')

[32m2025-04-14 13:17:07[0m | [1mINFO    [0m | [36mDF[0m | [1mRunning on torch 2.1.1+cu121[0m
[32m2025-04-14 13:17:07[0m | [1mINFO    [0m | [36mDF[0m | [1mRunning on host op-mm-guestxr[0m
[32m2025-04-14 13:17:07[0m | [1mINFO    [0m | [36mDF[0m | [1mLoading model settings of D01_sb_none_NH_mono[0m
[32m2025-04-14 13:17:07[0m | [1mINFO    [0m | [36mDF[0m | [1mInitializing model `deepfilternet3`[0m


fatal: not a git repository (or any of the parent directories): .git


[32m2025-04-14 13:17:10[0m | [1mINFO    [0m | [36mDF[0m | [1mFound checkpoint /home/ubuntu/Data/DFN/D01_sb_none_NH_mono/checkpoints/model_118.ckpt.best with epoch 118[0m
[32m2025-04-14 13:17:10[0m | [1mINFO    [0m | [36mDF[0m | [1mRunning on device cuda:0[0m
[32m2025-04-14 13:17:10[0m | [1mINFO    [0m | [36mDF[0m | [1mModel loaded[0m
76
clipping...
[32m2025-04-14 13:17:34[0m | [1mINFO    [0m | [36mDF[0m | [1mLoading model settings of D02_mb_none_NH_mono[0m
[32m2025-04-14 13:17:34[0m | [1mINFO    [0m | [36mDF[0m | [1mInitializing model `deepfilternet3`[0m
[32m2025-04-14 13:17:34[0m | [1mINFO    [0m | [36mDF[0m | [1mFound checkpoint /home/ubuntu/Data/DFN/D02_mb_none_NH_mono/checkpoints/model_116.ckpt.best with epoch 116[0m
[32m2025-04-14 13:17:35[0m | [1mINFO    [0m | [36mDF[0m | [1mRunning on device cuda:0[0m
[32m2025-04-14 13:17:35[0m | [1mINFO    [0m | [36mDF[0m | [1mModel loaded[0m
19
clipping...
76
clipping...
[32m2025

In [20]:
#exclude 18, 19, 76 i 94