# In [None]:
#https://github.com/yongxuUSTC/sednn

import os
import soundfile
import numpy as np
import argparse
import csv
import time
import matplotlib.pyplot as plt
from scipy import signal
import pickle
import cPickle
import h5py
from sklearn import preprocessing

import prepare_data as data_preparation
import config as configuration


def create_directory(directory):
    """Create ``directory`` (including missing parents) unless it already exists.

    Args:
        directory: Path of the directory to create.
    """
    if os.path.exists(directory):
        return
    os.makedirs(directory)

        
def read_audio_file(file_path, target_sample_rate=None):
    """Read an audio file as a mono signal, optionally resampling it.

    Args:
        file_path: Path to an audio file readable by ``soundfile``.
        target_sample_rate: If given and different from the file's native
            rate, the signal is resampled to this rate.

    Returns:
        Tuple ``(audio, sample_rate)``: the 1-D signal and its sample rate
        (``target_sample_rate`` when resampling took place).
    """
    (audio, sample_rate) = soundfile.read(file_path)
    # Down-mix multi-channel audio to mono by averaging the channels.
    if audio.ndim > 1:
        audio = np.mean(audio, axis=1)
    if target_sample_rate is not None and sample_rate != target_sample_rate:
        # Polyphase resampling from scipy, which this module already depends
        # on; resample_poly reduces the up/down ratio by their gcd itself.
        audio = signal.resample_poly(audio, int(target_sample_rate), int(sample_rate))
        sample_rate = target_sample_rate
    return audio, sample_rate


def write_audio_file(file_path, audio_data, sample_rate):
    """Write ``audio_data`` to ``file_path`` at ``sample_rate`` via soundfile."""
    soundfile.write(file_path, audio_data, sample_rate)

def generate_mixture_csv(arguments):
    """Write a tab-separated CSV pairing each speech file with noise segments.

    Reads these attributes from ``arguments``: ``workspace_directory``,
    ``speech_directory``, ``noise_directory``, ``data_type`` and
    ``amplification`` (number of noise files drawn per speech file for
    training data).

    The CSV is written to ``<workspace>/mixture_csvs/<data_type>.csv`` with
    the columns ``speech_file, noise_file, noise_start, noise_end``. The
    offsets are sample indices at ``configuration.sample_rate``, the same
    rate ``calculate_features`` reads the audio at.

    Raises:
        Exception: if ``data_type`` is not a training or testing type.
    """
    workspace_directory = arguments.workspace_directory
    speech_directory = arguments.speech_directory
    noise_directory = arguments.noise_directory
    data_type = arguments.data_type
    amplification = arguments.amplification
    sample_rate = configuration.sample_rate

    # Accept the long names originally supported here as well as the short
    # 'train' / 'test' names used by the training / inference code.
    if data_type in ('training', 'train'):
        is_training = True
    elif data_type in ('testing', 'test'):
        is_training = False
    else:
        raise Exception("data_type must be training | testing!")

    # Sorted so that the seeded random choices are reproducible regardless of
    # the order os.listdir happens to return.
    speech_file_names = sorted(file_name for file_name in os.listdir(speech_directory)
                               if file_name.lower().endswith(".wav"))
    noise_file_names = sorted(file_name for file_name in os.listdir(noise_directory)
                              if file_name.lower().endswith(".wav"))

    random_state = np.random.RandomState(0)
    output_csv_path = os.path.join(workspace_directory, "mixture_csvs", "%s.csv" % data_type)
    create_directory(os.path.dirname(output_csv_path))

    count = 0
    with open(output_csv_path, 'w') as csv_file:
        csv_file.write("%s\t%s\t%s\t%s\n" % ("speech_file", "noise_file", "noise_start", "noise_end"))
        for speech_file_name in speech_file_names:
            # Read at the configured rate so the offsets below match the
            # resampled audio that calculate_features mixes.
            speech_file_path = os.path.join(speech_directory, speech_file_name)
            (speech_audio, _) = read_audio_file(speech_file_path, target_sample_rate=sample_rate)
            speech_length = len(speech_audio)

            if is_training:
                # Mix each speech file with `amplification` distinct random noises.
                selected_noise_file_names = random_state.choice(
                    noise_file_names, size=amplification, replace=False)
            else:
                # Mix each speech file with every noise file.
                selected_noise_file_names = noise_file_names

            for noise_file_name in selected_noise_file_names:
                noise_file_path = os.path.join(noise_directory, noise_file_name)
                (noise_audio, _) = read_audio_file(noise_file_path, target_sample_rate=sample_rate)
                noise_length = len(noise_audio)

                if noise_length <= speech_length:
                    # Short noise is tiled to the speech length later on.
                    noise_start = 0
                    noise_end = speech_length
                else:
                    # Pick a random speech-length window; valid starts are
                    # 0 .. noise_length - speech_length inclusive.
                    noise_start = random_state.randint(0, noise_length - speech_length + 1, size=1)[0]
                    noise_end = noise_start + speech_length

                if count % 100 == 0:
                    print(count)

                count += 1
                csv_file.write("%s\t%s\t%d\t%d\n" % (speech_file_name, noise_file_name, noise_start, noise_end))

    print(output_csv_path)
    print("Create %s mixture CSV finished!" % data_type)
    
def calculate_features(args):
    """Mix speech with noise at a given SNR and store mixture audio and spectrograms.

    Reads these attributes from ``args``: ``workspace_path``,
    ``speech_directory``, ``noise_directory``, ``data_type`` and ``snr``
    (dB). For each row of ``mixture_csvs/<data_type>.csv``:

    1. It writes the mixture wav under ``mixed_audios/spectrogram/...``.
    2. It writes a pickle under ``features/spectrogram/...`` holding
       ``[mixed_complex_spectrogram, speech_spectrogram, noise_spectrogram,
       alpha, output_base_name]``.
    """
    workspace_path = args.workspace_path
    speech_directory = args.speech_directory
    noise_directory = args.noise_directory
    data_type = args.data_type
    snr = args.snr
    sample_rate = configuration.sample_rate

    # Open mixture CSV (binary mode, as the Python 2 csv module expects).
    mixture_csv_path = os.path.join(workspace_path, "mixture_csvs", "%s.csv" % data_type)
    with open(mixture_csv_path, 'rb') as csv_file:
        reader = csv.reader(csv_file, delimiter='\t')
        data_list = list(reader)

    t1 = time.time()
    # Row 0 is the header written by generate_mixture_csv.
    for count, row in enumerate(data_list[1:]):
        [speech_name, noise_name, noise_start, noise_end] = row
        noise_start = int(noise_start)
        noise_end = int(noise_end)

        # Read speech and noise at the configured sample rate.
        speech_path = os.path.join(speech_directory, speech_name)
        (speech_audio, _) = read_audio_file(speech_path, target_sample_rate=sample_rate)
        noise_path = os.path.join(noise_directory, noise_name)
        (noise_audio, _) = read_audio_file(noise_path, target_sample_rate=sample_rate)

        if len(noise_audio) < len(speech_audio):
            # Repeat short noise until it covers the speech, then trim.
            repetitions = int(np.ceil(float(len(speech_audio)) / float(len(noise_audio))))
            noise_audio = np.tile(noise_audio, repetitions)[0:len(speech_audio)]
        else:
            # Use the segment chosen when the mixture CSV was generated.
            noise_audio = noise_audio[noise_start:noise_end]

        # Scale speech so that 20*log10(rms(speech)/rms(noise)) == snr.
        scaling_factor = get_scaling_factor(speech_audio, noise_audio, snr=snr)
        speech_audio = speech_audio * scaling_factor

        # Peak-normalised mixture, speech and noise, plus the applied gain.
        (mixed_audio, speech_audio, noise_audio, alpha) = mix_sources(speech_audio, noise_audio)

        # Write out mixed audio.
        output_base_name = "%s.%s" % (os.path.splitext(speech_name)[0], os.path.splitext(noise_name)[0])
        output_audio_path = os.path.join(workspace_path, "mixed_audios", "spectrogram",
                                         data_type, "%ddb" % int(snr), "%s.wav" % output_base_name)
        create_directory(os.path.dirname(output_audio_path))
        write_audio_file(output_audio_path, mixed_audio, sample_rate)

        # Complex STFT of the mixture (phase is needed for reconstruction);
        # magnitudes only for the clean targets.
        mixed_complex_spectrogram = calculate_spectrogram(mixed_audio, mode='complex')
        speech_spectrogram = calculate_spectrogram(speech_audio, mode='magnitude')
        noise_spectrogram = calculate_spectrogram(noise_audio, mode='magnitude')

        # Write out features.
        output_feature_path = os.path.join(workspace_path, "features", "spectrogram",
                                           data_type, "%ddb" % int(snr), "%s.p" % output_base_name)
        create_directory(os.path.dirname(output_feature_path))
        data = [mixed_complex_spectrogram, speech_spectrogram, noise_spectrogram, alpha, output_base_name]
        with open(output_feature_path, 'wb') as feature_file:
            cPickle.dump(data, feature_file, protocol=cPickle.HIGHEST_PROTOCOL)

        if count % 100 == 0:
            print(count)

    print("Extracting feature time: %s" % (time.time() - t1))

def root_mean_square(y):
    """Return the root-mean-square of ``y`` along axis 0."""
    squared_magnitude = np.abs(y) ** 2
    return np.sqrt(np.mean(squared_magnitude, axis=0))

def get_scaling_factor(source, noise, snr, method='rms'):
    """Return the gain that, applied to ``source``, yields the requested SNR.

    The SNR in dB is defined as ``20 * log10(rms(source) / rms(noise))``.
    ``method`` is accepted but not read; only the RMS definition is used.
    """
    current_ratio = root_mean_square(source) / root_mean_square(noise)
    desired_ratio = 10. ** (float(snr) / 20.)
    return desired_ratio / current_ratio
                              
def mix_sources(source1, source2):
    """Add two sources and peak-normalise the result.

    The mixture and both sources are scaled by the same gain ``alpha`` so
    that the mixture's peak absolute value becomes 1. The input arrays are
    left unmodified.

    Returns:
        Tuple ``(mixed_audio, source1, source2, alpha)`` holding the scaled
        signals and the applied gain.
    """
    mixed_audio = source1 + source2
    peak = np.max(np.abs(mixed_audio))
    # A silent mixture has no peak to normalise against; keep unit gain
    # instead of producing inf / NaN samples.
    alpha = 1. / peak if peak > 0 else 1.
    return mixed_audio * alpha, source1 * alpha, source2 * alpha, alpha
    
def calculate_spectrogram(audio, mode, window_length=None, overlap=None):
    """Compute a one-sided Hamming-window STFT of ``audio``, frames first.

    Args:
        audio: 1-D signal.
        mode: ``'magnitude'`` (returns float32) or ``'complex'`` (returns complex64).
        window_length: Frame length in samples; read from ``configuration``
            when omitted.
        overlap: Overlap between consecutive frames in samples; read from
            ``configuration`` when omitted.

    Returns:
        Array of shape ``(n_frames, window_length // 2 + 1)``.

    Raises:
        Exception: if ``mode`` is not one of the supported values.
    """
    # Validate before doing any work.
    if mode == 'magnitude':
        output_dtype = np.float32
    elif mode == 'complex':
        output_dtype = np.complex64
    else:
        raise Exception("Incorrect mode!")

    # NOTE(review): the inference code reads n_window / n_overlap from its
    # config, while this function originally read window_length / overlap.
    # Accept either name until the config module is confirmed.
    if window_length is None:
        window_length = getattr(configuration, 'window_length', None)
        if window_length is None:
            window_length = configuration.n_window
    if overlap is None:
        overlap = getattr(configuration, 'overlap', None)
        if overlap is None:
            overlap = configuration.n_overlap

    hamming_window = np.hamming(window_length)
    _, _, spectrogram = signal.spectrogram(
        audio,
        window=hamming_window,
        nperseg=window_length,
        noverlap=overlap,
        detrend=False,
        return_onesided=True,
        mode=mode)
    # scipy returns (n_freq, n_frames); transpose to frames-first.
    return spectrogram.T.astype(output_dtype)
            
def prepare_features(args):
    """Pack per-file spectrogram features into a single HDF5 training file.

    Reads ``args.workspace``, ``args.data_type``, ``args.snr``,
    ``args.n_concat`` (context frames per input segment) and ``args.n_hop``.

    The inputs are log-magnitude mixture segments of shape
    ``(n_segments, n_concat, n_freq)``. The targets are the log-magnitude
    clean-speech frames at each segment's centre, shape
    ``(n_segments, n_freq)``. Both are written as datasets ``x`` and ``y`` to
    ``packed_features/spectrogram/<data_type>/<snr>db/data.h5``.

    Raises:
        ValueError: if the feature directory contains no files.
    """
    workspace = args.workspace
    data_type = args.data_type
    snr = args.snr
    n_concat = args.n_concat
    n_hop = args.n_hop

    x_all = []  # per file: (n_segments, n_concat, n_freq)
    y_all = []  # per file: (n_segments, n_freq)

    t1 = time.time()

    feature_directory = os.path.join(workspace, "features", "spectrogram", data_type, "%ddb" % int(snr))
    feature_names = os.listdir(feature_directory)
    # Integer division: the padding width / centre index must be an int on
    # both Python 2 and Python 3.
    n_padding = (n_concat - 1) // 2

    for cnt, feature_name in enumerate(feature_names):
        feature_path = os.path.join(feature_directory, feature_name)
        with open(feature_path, 'rb') as feature_file:
            data = cPickle.load(feature_file)
        [mixed_complex_spectrogram, speech_spectrogram, _, _, _] = data
        mixed_spectrogram = np.abs(mixed_complex_spectrogram)

        # Pad both ends with border frames so every frame can be a centre.
        mixed_spectrogram = pad_with_border(mixed_spectrogram, n_padding)
        speech_spectrogram = pad_with_border(speech_spectrogram, n_padding)

        # Cut input spectrogram into 3D segments of n_concat frames.
        mixed_spectrogram_3d = segment_2d_to_3d(mixed_spectrogram, aggregation_number=n_concat, hop=n_hop)
        x_all.append(mixed_spectrogram_3d)

        # Cut target the same way and keep only each segment's centre frame.
        speech_spectrogram_3d = segment_2d_to_3d(speech_spectrogram, aggregation_number=n_concat, hop=n_hop)
        y_all.append(speech_spectrogram_3d[:, n_padding, :])

        if cnt % 100 == 0:
            print(cnt)

    if not x_all:
        raise ValueError("No feature files found in %s" % feature_directory)

    x_all = np.concatenate(x_all, axis=0)   # (n_segments, n_concat, n_freq)
    y_all = np.concatenate(y_all, axis=0)   # (n_segments, n_freq)

    x_all = apply_logarithm(x_all).astype(np.float32)
    y_all = apply_logarithm(y_all).astype(np.float32)

    # Write out data to .h5 file.
    out_path = os.path.join(workspace, "packed_features", "spectrogram", data_type, "%ddb" % int(snr), "data.h5")
    create_directory(os.path.dirname(out_path))
    with h5py.File(out_path, 'w') as h5_file:
        h5_file.create_dataset('x', data=x_all)
        h5_file.create_dataset('y', data=y_all)

    print("Write out to %s" % out_path)
    print("Prepare features finished! %s s" % (time.time() - t1,))
    
def apply_logarithm(x):
    """Return ``log(x + 1e-8)``; the small offset keeps zeros finite."""
    epsilon = 1e-08
    return np.log(x + epsilon)
    
def segment_2d_to_3d(x, aggregation_number, hop):
    """Cut a 2-D array into (possibly overlapping) blocks of consecutive rows.

    Args:
        x: Array of shape ``(n_rows, n_features)``.
        aggregation_number: Number of rows per block.
        hop: Row stride between block starts; must be positive.

    Returns:
        Array of shape ``(n_blocks, aggregation_number, n_features)``. Inputs
        shorter than one block are zero-padded at the end to one block.

    Raises:
        ValueError: if ``hop`` is not positive (the block loop could never
            advance).
    """
    if hop <= 0:
        raise ValueError("hop must be positive, got %r" % (hop,))

    len_x, n_features = x.shape
    if len_x < aggregation_number:
        # Pad to one block with zeros of the same dtype so the output dtype
        # does not depend on the input length.
        padding = np.zeros((aggregation_number - len_x, n_features), dtype=x.dtype)
        x = np.concatenate((x, padding))
        len_x = aggregation_number

    starts = range(0, len_x - aggregation_number + 1, hop)
    return np.array([x[i:i + aggregation_number] for i in starts])

def pad_with_border(x, n_padding):
    """Repeat the first and last rows of ``x`` ``n_padding`` times at each end."""
    leading = [x[:1]] * n_padding
    trailing = [x[-1:]] * n_padding
    return np.concatenate(leading + [x] + trailing, axis=0)
                              
def compute_data_scaler(args):
    """Fit a StandardScaler on the packed inputs and pickle it next to them.

    Reads ``args.workspace``, ``args.data_type`` and ``args.snr``. The
    function:

    1. Loads dataset ``x`` from
       ``packed_features/spectrogram/<data_type>/<snr>db/data.h5``.
    2. Flattens it to ``(n_segments * n_concat, n_freq)``, so the mean and
       scale are per frequency bin.
    3. Writes the fitted scaler to ``scaler.p`` in the same directory.
    """
    workspace_directory = args.workspace
    data_type = args.data_type
    snr = args.snr

    t1 = time.time()
    packed_directory = os.path.join(workspace_directory, "packed_features", "spectrogram",
                                    data_type, "%ddb" % int(snr))

    # Load data.
    hdf5_file_path = os.path.join(packed_directory, "data.h5")
    with h5py.File(hdf5_file_path, 'r') as h5_file:
        x_data = np.array(h5_file.get('x'))     # (n_segments, n_concat, n_freq)

    # Compute the scaler over all frames of all segments.
    (n_segments, n_concat, n_freq) = x_data.shape
    x2d_data = x_data.reshape((n_segments * n_concat, n_freq))
    data_scaler = preprocessing.StandardScaler(with_mean=True, with_std=True).fit(x2d_data)
    print(data_scaler.mean_)
    print(data_scaler.scale_)

    # Write out the scaler.
    scaler_file_path = os.path.join(packed_directory, "scaler.p")
    create_directory(os.path.dirname(scaler_file_path))
    with open(scaler_file_path, 'wb') as scaler_file:
        pickle.dump(data_scaler, scaler_file)

    print("Save the scaler to %s" % scaler_file_path)
    print("Compute the scaler finished! %s s" % (time.time() - t1))
    
def scale_2d_data(x2d, scaler):
    """Apply the fitted ``scaler``'s ``transform`` to a 2-D array."""
    transformed = scaler.transform(x2d)
    return transformed
    
def scale_3d_data(x3d, scaler):
    """Apply ``scaler.transform`` row-wise to a 3-D array.

    The array is flattened to ``(n_segments * n_concat, n_freq)`` for the
    transform and reshaped back to ``(n_segments, n_concat, n_freq)``.
    """
    n_segments, n_concat, n_freq = x3d.shape
    flat = scaler.transform(x3d.reshape((n_segments * n_concat, n_freq)))
    return flat.reshape((n_segments, n_concat, n_freq))
    
def inverse_scale_2d_data(x2d, scaler):
    """Undo standard scaling row-wise: ``x * scaler.scale_ + scaler.mean_``."""
    row_scale = scaler.scale_[np.newaxis, :]
    row_mean = scaler.mean_[np.newaxis, :]
    return x2d * row_scale + row_mean
    
def load_data_from_hdf5(hdf5_file_path):
    """Load datasets ``x`` and ``y`` from an HDF5 file into NumPy arrays.

    Returns:
        Tuple ``(x_data, y_data)``. As written by ``prepare_features`` these
        are presumably ``(n_segments, n_concat, n_freq)`` and
        ``(n_segments, n_freq)`` respectively.
    """
    with h5py.File(hdf5_file_path, 'r') as h5_file:
        x_data, y_data = [np.array(h5_file.get(key)) for key in ('x', 'y')]
    return x_data, y_data

def np_mean_absolute_error(y_true, y_pred):
    """Return the mean absolute difference between ``y_pred`` and ``y_true``."""
    absolute_errors = np.abs(y_pred - y_true)
    return np.mean(absolute_errors)
    
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest='mode')

    # The command-line flags are unchanged; `dest` maps each flag onto the
    # attribute name that the corresponding handler reads.
    parser_create_mixture_csv = subparsers.add_parser('create_mixture_csv')
    parser_create_mixture_csv.add_argument('--workspace', dest='workspace_directory', type=str, required=True)
    parser_create_mixture_csv.add_argument('--speech_dir', dest='speech_directory', type=str, required=True)
    parser_create_mixture_csv.add_argument('--noise_dir', dest='noise_directory', type=str, required=True)
    parser_create_mixture_csv.add_argument('--data_type', type=str, required=True)
    parser_create_mixture_csv.add_argument('--magnification', dest='amplification', type=int, default=1)

    parser_calculate_mixture_features = subparsers.add_parser('calculate_mixture_features')
    parser_calculate_mixture_features.add_argument('--workspace', dest='workspace_path', type=str, required=True)
    parser_calculate_mixture_features.add_argument('--speech_dir', dest='speech_directory', type=str, required=True)
    parser_calculate_mixture_features.add_argument('--noise_dir', dest='noise_directory', type=str, required=True)
    parser_calculate_mixture_features.add_argument('--data_type', type=str, required=True)
    parser_calculate_mixture_features.add_argument('--snr', type=float, required=True)

    parser_pack_features = subparsers.add_parser('pack_features')
    parser_pack_features.add_argument('--workspace', type=str, required=True)
    parser_pack_features.add_argument('--data_type', type=str, required=True)
    parser_pack_features.add_argument('--snr', type=float, required=True)
    parser_pack_features.add_argument('--n_concat', type=int, required=True)
    parser_pack_features.add_argument('--n_hop', type=int, required=True)

    parser_compute_data_scaler = subparsers.add_parser('compute_data_scaler')
    parser_compute_data_scaler.add_argument('--workspace', type=str, required=True)
    parser_compute_data_scaler.add_argument('--data_type', type=str, required=True)
    parser_compute_data_scaler.add_argument('--snr', type=float, required=True)

    # Sub-command name -> handler defined in this module.
    handlers = {
        'create_mixture_csv': generate_mixture_csv,
        'calculate_mixture_features': calculate_features,
        'pack_features': prepare_features,
        'compute_data_scaler': compute_data_scaler,
    }

    args = parser.parse_args()
    if args.mode not in handlers:
        raise Exception("Error!")
    handlers[args.mode](args)

# In [None]:
import numpy as np

class DataGenerator(object):
    """Yield shuffled mini-batches of paired inputs and targets.

    In ``'train'`` mode the generator never ends and reshuffles after every
    complete pass. In ``'test'`` mode it stops after one pass over the data,
    or earlier once ``max_iterations`` batches have been yielded. The cap is
    only honoured in test mode.
    """

    def __init__(self, batch_size, data_type, max_iterations=None):
        assert data_type in ['train', 'test']
        self._batch_size_ = batch_size
        self._data_type_ = data_type
        self._max_iterations_ = max_iterations

    def generate(self, x_data, y_data):
        """Yield ``(x_batch, y_batch)`` pairs from ``x_data[0]`` / ``y_data[0]``.

        Only the first element of each list argument is used; both must be
        indexable by an integer index array (e.g. NumPy arrays).
        """
        features = x_data[0]
        targets = y_data[0]
        step = self._batch_size_
        total = len(features)
        in_test_mode = self._data_type_ == 'test'
        iteration_cap = self._max_iterations_ if in_test_mode else None

        order = np.arange(total)
        np.random.shuffle(order)

        n_batches = 0
        n_epochs = 0
        cursor = 0
        while iteration_cap is None or n_batches != iteration_cap:
            n_batches += 1
            if cursor >= total:
                n_epochs += 1
                # A test pass ends as soon as the first epoch is exhausted.
                if in_test_mode and n_epochs == 1:
                    return
                cursor = 0
                np.random.shuffle(order)

            chosen = order[cursor:min(cursor + step, total)]
            cursor += step
            yield features[chosen], targets[chosen]

# In [None]:
import numpy as np
import numpy
import decimal

def recover_wav(predicted_abs_spectrum, ground_truth_spectrum, overlap, window_function, target_wav_length=None):
    """Rebuild a waveform from a predicted magnitude spectrogram.

    The phase is taken from ``ground_truth_spectrum``. The resulting complex
    one-sided spectrogram is then inverted exactly as ``recover_gt_wav``
    does: mirror to full spectrum, inverse FFT, overlap-add, and optional
    pad/truncate to ``target_wav_length``.
    """
    complex_spectrum = real_to_complex(predicted_abs_spectrum, ground_truth_spectrum)
    return recover_gt_wav(complex_spectrum, overlap, window_function, target_wav_length)
    
def real_to_complex(predicted_abs_spectrum, ground_truth_spectrum):
    """Combine a magnitude spectrum with the phase of ``ground_truth_spectrum``."""
    return predicted_abs_spectrum * np.exp(1j * np.angle(ground_truth_spectrum))
    
def half_to_whole(x):
    """Expand a one-sided spectrum to the full conjugate-symmetric spectrum.

    ``x`` is 2-D with frequency bins along axis 1. The columns
    ``n-2 .. 1``, reversed and conjugated, are appended as the
    negative-frequency bins.
    """
    negative_bins = np.conj(x[:, -2:0:-1])
    return np.concatenate((x, negative_bins), axis=1)

def ifft_to_wav(x):
    """Return the real part of the inverse FFT of ``x`` along its last axis."""
    time_frames = np.fft.ifft(x)
    return time_frames.real

def pad_or_truncate(s, target_length):
    """Return ``s`` cut to, or zero-padded up to, exactly ``target_length`` samples."""
    shortfall = target_length - len(s)
    if shortfall > 0:
        return np.concatenate((s, np.zeros(shortfall)))
    return s[:target_length]

def recover_gt_wav(x, overlap, window_function, target_wav_length=None):
    """Invert a complex one-sided spectrogram to a waveform by overlap-add.

    Args:
        x: Complex one-sided spectrogram, frames along axis 0.
        overlap: Overlap between frames in samples.
        window_function: Callable ``f(n)`` returning the analysis window.
        target_wav_length: If truthy, the output is padded/truncated to it.
    """
    frames = ifft_to_wav(half_to_whole(x))
    _, frame_length = frames.shape
    waveform = deframesig(frames=frames, siglen=0, frame_len=frame_length,
                          frame_step=frame_length - overlap, winfunc=window_function)
    if target_wav_length:
        waveform = pad_or_truncate(waveform, target_wav_length)
    return waveform

def deframesig(frames, siglen, frame_len, frame_step, winfunc=lambda x: numpy.ones((x,))):    
    """Overlap-add a matrix of frames back into a 1-D signal.

    Args:
        frames: 2-D array; row ``i`` is placed at samples
            ``i * frame_step .. i * frame_step + frame_len - 1``. Its second
            dimension must equal ``frame_len`` (checked with ``assert``).
        siglen: Number of samples to return; ``<= 0`` means the full
            overlap-added length ``(numframes - 1) * frame_step + frame_len``.
        frame_len: Frame length in samples (rounded half-up to an int).
        frame_step: Hop between frame starts in samples (rounded half-up).
        winfunc: Callable ``winfunc(frame_len)`` returning the window; all
            ones by default.

    Returns:
        1-D array with the first ``siglen`` samples of the summed frames,
        divided sample-wise by the summed window values of every frame
        covering that sample (plus 1e-15 per frame).

    NOTE(review): the frames are summed without multiplying by ``win`` here;
    presumably they are expected to be already windowed — confirm.
    """
    frame_len = round_half_up(frame_len)
    frame_step = round_half_up(frame_step)
    numframes = numpy.shape(frames)[0]
    assert numpy.shape(frames)[1] == frame_len, '"frames" matrix is wrong size, 2nd dim is not equal to frame_len'
 
    # indices[i, j] = i * frame_step + j: the output sample that frame i, column j maps to.
    indices = numpy.tile(numpy.arange(0,frame_len),(numframes,1)) + numpy.tile(numpy.arange(0,numframes*frame_step,frame_step),(frame_len,1)).T
    indices = numpy.array(indices,dtype=numpy.int32)
    # Total length covered by all frames.
    padlen = (numframes-1)*frame_step + frame_len   
    
    if siglen <= 0:
        siglen = padlen
    
    reconstructed_signal = numpy.zeros((padlen,))
    window_correction = numpy.zeros((padlen,))
    win = winfunc(frame_len)
    
    # Accumulate frames and, in parallel, the window weight each output sample receives.
    for i in range(0, numframes):
        window_correction[indices[i,:]] = window_correction[indices[i,:]] + win + 1e-15 #add a little bit so it is never zero
        reconstructed_signal[indices[i,:]] = reconstructed_signal[indices[i,:]] + frames[i,:]
        
    # Normalise by the accumulated window weight.
    reconstructed_signal = reconstructed_signal / window_correction
    return reconstructed_signal[0:siglen]
    
def round_half_up(number):
    """Round ``number`` to the nearest int, with ties going away from zero."""
    one = decimal.Decimal('1')
    rounded = decimal.Decimal(number).quantize(one, rounding=decimal.ROUND_HALF_UP)
    return int(rounded)

# In [None]:
import numpy as np
import os
import pickle
import cPickle
import h5py
import argparse
import time
import glob
import matplotlib.pyplot as plt

from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.optimizers import Adam
from keras.models import load_model

# Audio / STFT settings for this script.
# NOTE(review): presumably these must match the values used when the
# features were extracted — confirm against the config module.
sample_rate = 16000   # presumably Hz
n_window = 512        # presumably STFT frame length in samples
n_overlap = 256       # presumably overlap between consecutive frames in samples

def evaluate(model, generator, x, y):
    """Return the mean absolute error of ``model`` over batches from ``generator``.

    Args:
        model: Object with a ``predict(batch_x)`` method (e.g. a Keras model).
        generator: Object whose ``generate([x], [y])`` yields
            ``(batch_x, batch_y)`` pairs, such as ``DataGenerator``.
        x: Inputs, passed to the generator wrapped in a list.
        y: Targets, passed to the generator wrapped in a list.

    Raises:
        ValueError: if the generator yields no batches.
    """
    predicted_all, y_all = [], []

    # Inference in mini-batches. Positional arguments, because
    # DataGenerator.generate names its parameters x_data / y_data.
    for (batch_x, batch_y) in generator.generate([x], [y]):
        predicted_all.append(model.predict(batch_x))
        y_all.append(batch_y)

    if not predicted_all:
        raise ValueError("generator yielded no batches")

    predicted_all = np.concatenate(predicted_all, axis=0)
    y_all = np.concatenate(y_all, axis=0)

    # Mean absolute error (same formula as np_mean_absolute_error).
    return np.mean(np.abs(predicted_all - y_all))
    

def _save_training_stats(stats_directory, iteration, training_loss, test_loss):
    """Pickle the losses at ``iteration`` to ``<stats_directory>/<iteration>iters.p``."""
    stat_dict = {'iteration': iteration,
                 'training_loss': training_loss,
                 'test_loss': test_loss}
    stat_path = os.path.join(stats_directory, "%diters.p" % iteration)
    with open(stat_path, 'wb') as stat_file:
        cPickle.dump(stat_dict, stat_file, protocol=cPickle.HIGHEST_PROTOCOL)


def train(args):
    """Train the fully connected enhancement DNN and save stats and checkpoints.

    Reads ``args.workspace``, ``args.tr_snr``, ``args.te_snr`` and ``args.lr``.
    Packed features and the training scaler are loaded from
    ``<workspace>/packed_features/spectrogram/{train,test}/<snr>db/``. Losses
    are logged every 1000 iterations and the model is saved every 5000;
    training stops after 10001 updates.
    """
    print(args)
    workspace = args.workspace
    training_snr = args.tr_snr
    testing_snr = args.te_snr
    learning_rate = args.lr

    # Load data.
    t1 = time.time()
    training_hdf5_path = os.path.join(workspace, "packed_features", "spectrogram", "train",
                                      "%ddb" % int(training_snr), "data.h5")
    testing_hdf5_path = os.path.join(workspace, "packed_features", "spectrogram", "test",
                                     "%ddb" % int(testing_snr), "data.h5")
    (training_x, training_y) = load_data_from_hdf5(training_hdf5_path)
    (testing_x, testing_y) = load_data_from_hdf5(testing_hdf5_path)
    print(training_x.shape, training_y.shape)
    print(testing_x.shape, testing_y.shape)
    print("Load data time: %s s" % (time.time() - t1,))

    batch_size = 500
    print("%d iterations / epoch" % int(training_x.shape[0] / batch_size))

    # Scale both sets with the scaler fitted on the training data.
    t1 = time.time()
    scaler_path = os.path.join(workspace, "packed_features", "spectrogram", "train",
                               "%ddb" % int(training_snr), "scaler.p")
    with open(scaler_path, 'rb') as scaler_file:
        scaler = pickle.load(scaler_file)
    training_x = scale_3d_data(training_x, scaler)
    training_y = scale_2d_data(training_y, scaler)
    testing_x = scale_3d_data(testing_x, scaler)
    testing_y = scale_2d_data(testing_y, scaler)
    print("Scale data time: %s s" % (time.time() - t1,))

    # Build model: flatten the context window, three ReLU layers with
    # dropout, then a linear output frame.
    (_, n_concat, n_freq) = training_x.shape
    n_hidden = 2048

    model = Sequential()
    model.add(Flatten(input_shape=(n_concat, n_freq)))
    for _ in range(3):
        model.add(Dense(n_hidden, activation='relu'))
        model.add(Dropout(0.2))
    model.add(Dense(n_freq, activation='linear'))
    model.summary()

    model.compile(loss='mean_absolute_error',
                  optimizer=Adam(lr=learning_rate))

    train_data_generator = DataGenerator(batch_size=batch_size, data_type='train')
    eval_test_data_generator = DataGenerator(batch_size=batch_size, data_type='test', max_iterations=100)
    eval_train_data_generator = DataGenerator(batch_size=batch_size, data_type='test', max_iterations=100)

    # Directories for saving models and training stats.
    model_directory = os.path.join(workspace, "models", "%ddb" % int(training_snr))
    create_directory(model_directory)
    stats_directory = os.path.join(workspace, "training_stats", "%ddb" % int(training_snr))
    create_directory(stats_directory)

    # Loss before training.
    iteration = 0
    training_loss = evaluate(model, eval_train_data_generator, training_x, training_y)
    test_loss = evaluate(model, eval_test_data_generator, testing_x, testing_y)
    print("Iteration: %d, training_loss: %f, test_loss: %f" % (iteration, training_loss, test_loss))
    _save_training_stats(stats_directory, iteration, training_loss, test_loss)

    # Train.
    t1 = time.time()
    for (batch_x, batch_y) in train_data_generator.generate([training_x], [training_y]):
        model.train_on_batch(batch_x, batch_y)
        iteration += 1

        # Validate and save training stats.
        if iteration % 1000 == 0:
            training_loss = evaluate(model, eval_train_data_generator, training_x, training_y)
            test_loss = evaluate(model, eval_test_data_generator, testing_x, testing_y)
            print("Iteration: %d, training_loss: %f, test_loss: %f" % (iteration, training_loss, test_loss))
            _save_training_stats(stats_directory, iteration, training_loss, test_loss)

        # Save model.
        if iteration % 5000 == 0:
            model_path = os.path.join(model_directory, "md_%diters.h5" % iteration)
            model.save(model_path)
            print("Saved model to %s" % model_path)

        if iteration == 10001:
            break

    print("Training time: %s s" % (time.time() - t1))

def _plot_spectrograms(mixed_x, speech_x, pred, te_snr):
    """Show mixture, clean and enhanced log spectrograms stacked vertically."""
    fig, axs = plt.subplots(3, 1, sharex=False)
    axs[0].matshow(mixed_x.T, origin='lower', aspect='auto', cmap='jet')
    axs[1].matshow(speech_x.T, origin='lower', aspect='auto', cmap='jet')
    axs[2].matshow(pred.T, origin='lower', aspect='auto', cmap='jet')
    axs[0].set_title("%ddb mixture log spectrogram" % int(te_snr))
    axs[1].set_title("Clean speech log spectrogram")
    axs[2].set_title("Enhanced speech log spectrogram")
    for axis in axs:
        axis.xaxis.tick_bottom()
    plt.tight_layout()
    plt.show()


def inference(args):
    """Enhance every test feature file with a trained model and write the wavs.

    Reads ``args.workspace``, ``args.tr_snr``, ``args.te_snr``,
    ``args.n_concat``, ``args.iteration`` and ``args.visualize``. Enhanced
    audio is written to ``<workspace>/enh_wavs/test/<te_snr>db/<name>.enh.wav``.
    """
    print(args)
    workspace = args.workspace
    tr_snr = args.tr_snr
    te_snr = args.te_snr
    n_concat = args.n_concat
    iteration = args.iteration

    # STFT settings from this module's constants.
    # NOTE(review): these must match the settings used for feature extraction.
    window_length = n_window
    overlap = n_overlap
    fs = sample_rate
    scale = True

    # Load model.
    model_path = os.path.join(workspace, "models", "%ddb" % int(tr_snr), "md_%diters.h5" % iteration)
    model = load_model(model_path)

    # Load scaler.
    scaler_path = os.path.join(workspace, "packed_features", "spectrogram", "train",
                               "%ddb" % int(tr_snr), "scaler.p")
    with open(scaler_path, 'rb') as scaler_file:
        scaler = pickle.load(scaler_file)

    # Load test data.
    feat_dir = os.path.join(workspace, "features", "spectrogram", "test", "%ddb" % int(te_snr))
    feature_file_names = os.listdir(feat_dir)
    n_pad = (n_concat - 1) // 2

    for (cnt, feature_file_name) in enumerate(feature_file_names):
        # Load feature.
        feat_path = os.path.join(feat_dir, feature_file_name)
        with open(feat_path, 'rb') as feat_file:
            data = cPickle.load(feat_file)
        [mixed_complex_x, speech_x, _, _, file_name] = data
        mixed_x = np.abs(mixed_complex_x)

        # Pad so every frame gets a full context window, then go to log domain.
        mixed_x = pad_with_border(mixed_x, n_pad)
        mixed_x = apply_logarithm(mixed_x)
        speech_x = apply_logarithm(speech_x)

        # Scale data.
        if scale:
            mixed_x = scale_2d_data(mixed_x, scaler)
            speech_x = scale_2d_data(speech_x, scaler)

        # One n_concat-frame segment per original frame (hop of 1).
        mixed_x_3d = segment_2d_to_3d(mixed_x, aggregation_number=n_concat, hop=1)

        # Predict.
        pred = model.predict(mixed_x_3d)
        print(cnt, file_name)

        # Inverse scale.
        if scale:
            mixed_x = inverse_scale_2d_data(mixed_x, scaler)
            speech_x = inverse_scale_2d_data(speech_x, scaler)
            pred = inverse_scale_2d_data(pred, scaler)

        if args.visualize:
            _plot_spectrograms(mixed_x, speech_x, pred, te_snr)

        # Recover the enhanced wav using the mixture's phase.
        pred_sp = np.exp(pred)
        s = recover_wav(pred_sp, mixed_complex_x, overlap, np.hamming)
        # Compensate the amplitude change introduced by the windowed STFT and IFFT.
        s *= np.sqrt((np.hamming(window_length) ** 2).sum())

        # Write out enhanced wav.
        out_path = os.path.join(workspace, "enh_wavs", "test", "%ddb" % int(te_snr), "%s.enh.wav" % file_name)
        create_directory(os.path.dirname(out_path))
        write_audio_file(out_path, s, fs)
                  
if __name__ == '__main__':
    argument_parser = argparse.ArgumentParser()
    subparsers = argument_parser.add_subparsers(dest='mode')

    # The command-line flags are unchanged; `dest` maps them onto the
    # attribute names read by train() and inference() (tr_snr, te_snr, lr).
    train_parser = subparsers.add_parser('train')
    train_parser.add_argument('--workspace', type=str, required=True)
    train_parser.add_argument('--training_snr', dest='tr_snr', type=float, required=True)
    train_parser.add_argument('--testing_snr', dest='te_snr', type=float, required=True)
    train_parser.add_argument('--learning_rate', dest='lr', type=float, required=True)

    inference_parser = subparsers.add_parser('inference')
    inference_parser.add_argument('--workspace', type=str, required=True)
    inference_parser.add_argument('--training_snr', dest='tr_snr', type=float, required=True)
    inference_parser.add_argument('--testing_snr', dest='te_snr', type=float, required=True)
    inference_parser.add_argument('--n_concat', type=int, required=True)
    inference_parser.add_argument('--iteration', type=int, required=True)
    inference_parser.add_argument('--visualize', action='store_true', default=False)

    # NOTE(review): calculate_pesq is not defined in this file; its expected
    # attribute names are unknown, so this parser is left as-is.
    calculate_pesq_parser = subparsers.add_parser('calculate_pesq')
    calculate_pesq_parser.add_argument('--workspace', type=str, required=True)
    calculate_pesq_parser.add_argument('--speech_dir', type=str, required=True)
    calculate_pesq_parser.add_argument('--testing_snr', type=float, required=True)

    arguments = argument_parser.parse_args()

    if arguments.mode == 'train':
        train(arguments)
    elif arguments.mode == 'inference':
        inference(arguments)
    elif arguments.mode == 'calculate_pesq':
        calculate_pesq(arguments)
    else:
        raise Exception("Error!")