In [None]:
import os
import numpy as np
from scipy import signal as sg
import array
import random
import wave


#Audio pre-processing: mono mix-down followed by optional resampling
def pre_processing(data, Fs, down_sam):
    """Mix a waveform down to mono and optionally resample it.

    Args:
        data: Waveform array. A 2-D array is treated as stereo and reduced to
            the equal-weight average of column 0 and column 1; an array of any
            other rank is passed through untouched.
        Fs: Sampling rate of ``data``.
        down_sam: Target sampling rate, or ``None`` to keep the original rate.

    Returns:
        A ``(wavdata, Fs)`` tuple holding the processed waveform and the
        sampling rate that now applies to it (``down_sam`` when resampling was
        performed, otherwise the incoming ``Fs``).

    Note:
        No amplitude normalization is performed here.
    """
    # Two-dimensional input: average the first two columns into one channel
    mono = 0.5*data[:, 0] + 0.5*data[:, 1] if data.ndim == 2 else data

    # No target rate requested -> hand the (possibly mixed-down) signal back
    if down_sam is None:
        return mono, Fs

    # Polyphase resampling by the rational factor down_sam / Fs
    return sg.resample_poly(mono, down_sam, Fs), down_sam

def cal_adjusted_rms(clean_rms, snr):
    """Return the noise RMS needed to reach ``snr`` relative to ``clean_rms``.

    The result is ``clean_rms / 10**(snr / 20)``, i.e. the amplitude-ratio form
    of a decibel value.

    Args:
        clean_rms: RMS level of the clean signal.
        snr: Desired signal-to-noise ratio; anything ``float()`` accepts
            (number or numeric string).

    Returns:
        The RMS level the noise must be scaled to.
    """
    amplitude_ratio = 10 ** (float(snr) / 20)
    return clean_rms / amplitude_ratio
def cal_amp(wf):
    """Read all frames from an open wave reader and return them as float64.

    Args:
        wf: Object exposing ``getnframes()`` and ``readframes(n)`` (e.g. a
            reader returned by ``wave.open``). The raw bytes are interpreted
            as 16-bit integers.

    Returns:
        1-D ``numpy.float64`` array with one entry per 16-bit sample read.
    """
    frame_count = wf.getnframes()
    raw_bytes = wf.readframes(frame_count)
    # Interpret the byte string as int16, then widen so later arithmetic
    # (squaring, scaling) is done in floating point.
    samples = np.frombuffer(raw_bytes, dtype="int16")
    return samples.astype(np.float64)
def cal_rms(amp):
    """Compute the root-mean-square of ``amp`` along its last axis.

    Args:
        amp: Array of amplitudes.

    Returns:
        The square root of the mean of the squared values, reduced over the
        last axis (a scalar for 1-D input).
    """
    mean_square = np.mean(np.square(amp), axis=-1)
    return np.sqrt(mean_square)

In [None]:
#Setup
down_sam = None        #Downsampling rate (Hz); None (default) disables resampling

#Define random seed.
#Both generators are seeded: NumPy's global RNG and the stdlib `random`
#module. The mixing offsets are drawn with `random.randint`, so seeding
#NumPy alone does not make the generated dataset reproducible.
SEED = 32
np.random.seed(seed=SEED)
random.seed(SEED)

#Signal-to-noise ratios to generate. Kept as strings because they are also
#embedded verbatim in the output directory names.
sn_rates = ["-4.0", "-2.0", "0.0", "2.0", "4.0"]

#Sub-directories (one per sound-event class) to process
class_pathes = ["babycry", "gunshot", "glassbreak"]

In [None]:
def _write_wav(path, params, samples):
    """Write `samples` to `path` as 16-bit PCM using the given wave params."""
    # The context manager guarantees the header is finalized and the file
    # closed even if writing fails part-way.
    with wave.open(path, "wb") as out_wav:
        out_wav.setparams(params)
        # ndarray.tobytes() replaces array.array(...).tostring(), which was
        # removed in Python 3.9.
        out_wav.writeframes(samples.astype(np.int16).tobytes())


# The background-noise listing does not depend on the class, so read it once.
# Sorting makes the signal/noise pairing deterministic (os.listdir order is
# arbitrary).
bgm_files = sorted(os.listdir("./norm_dataset/bgs-10s/"))

for class_path in class_pathes:

    signal_files = sorted(os.listdir("./norm_dataset/" + class_path))
    if signal_files and not bgm_files:
        raise FileNotFoundError("no background files found in ./norm_dataset/bgs-10s/")

    for i, signal_file in enumerate(signal_files):

        clean_file = "./norm_dataset/" + class_path + "/" + signal_file
        # Wrap around when there are fewer noise files than signal files
        # instead of failing with an IndexError.
        noise_file = "./norm_dataset/bgs-10s/" + bgm_files[i % len(bgm_files)]

        # Read both inputs and release the file handles immediately.
        with wave.open(clean_file, "rb") as clean_wav:
            # All three outputs reuse the clean file's wave parameters, so the
            # noise file is assumed to share its channel count and sample
            # width -- TODO confirm against the dataset.
            out_params = clean_wav.getparams()
            clean_amp = cal_amp(clean_wav)
        with wave.open(noise_file, "rb") as noise_wav:
            noise_amp = cal_amp(noise_wav)

        if len(clean_amp) > len(noise_amp):
            raise ValueError(
                "clean signal {} is longer than background {}".format(clean_file, noise_file)
            )

        # Random position of the clean event inside the background clip.
        start = random.randint(0, len(noise_amp) - len(clean_amp))
        end = start + len(clean_amp)

        # Clean event zero-padded to the background length.
        clean_padded = np.zeros_like(noise_amp)
        clean_padded[start:end] = clean_amp

        # The SNR is defined over the overlapping region only.
        clean_rms = cal_rms(clean_amp)
        noise_rms = cal_rms(noise_amp[start:end])
        if noise_rms == 0:
            raise ValueError(
                "background segment of {} is silent; cannot scale to an SNR".format(noise_file)
            )

        for sn_rate in sn_rates:

            out_dir = "./mix_dataset2/sn" + sn_rate + "/" + class_path + "/" + '{0:03d}'.format(i)
            # makedirs creates missing parents and tolerates re-runs.
            os.makedirs(out_dir, exist_ok=True)

            adjusted_noise_rms = cal_adjusted_rms(clean_rms, float(sn_rate))
            scaled_noise = noise_amp * (adjusted_noise_rms / noise_rms)
            mixed = clean_padded + scaled_noise

            # Normalize by the absolute peak so that neither positive nor
            # negative excursions overflow int16; the same factor is applied
            # to all three signals so input == truth + bgm still holds.
            peak = np.abs(mixed).max()
            norm = 32767 / peak if peak > 0 else 1.0

            _write_wav(out_dir + "/input.wav", out_params, mixed * norm)
            _write_wav(out_dir + "/truth.wav", out_params, clean_padded * norm)
            _write_wav(out_dir + "/bgm.wav", out_params, scaled_noise * norm)