In [None]:
import sys
sys.path.append("../lib")
from models import build_ao_model
from generators import AudioGenerator
import scipy.io.wavfile as wavfile
import utils
from tensorflow.keras.callbacks import EarlyStopping

import tensorflow as tf
import numpy as np
import math
print('TensorFlow Version: {}'.format(tf.__version__))
# Check for a GPU: gpu_device_name() returns an empty string when none is visible,
# so query it once and branch on the result.
gpu_name = tf.test.gpu_device_name()
if gpu_name:
    print('Default GPU Device: {}'.format(gpu_name))
else:
    print('No GPU found. Please ensure you have installed TensorFlow correctly')

In [None]:
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

# Let TensorFlow allocate GPU memory on demand instead of reserving it all up
# front.
# - Memory growth must be configured before the GPUs are initialized, so it is
#   set here, before the compat.v1 session below is created.
# - It is applied to every visible GPU: TF requires the setting to be
#   consistent across GPUs.
# - The loop is a no-op on CPU-only machines (the previous cell allows for that
#   case), where indexing physical_devices[0] would raise IndexError.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)
    assert tf.config.experimental.get_memory_growth(gpu)

# TF1-compat session configured with the equivalent allow_growth option.
# NOTE(review): presumably kept for code paths that still use tf.compat.v1;
# confirm it is still needed.
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)

In [None]:
# Directories of .npy feature files for each split.
# "mix" holds the model inputs; "crm" holds the per-speaker targets
# (presumably complex ratio masks, going by the cRM naming -- confirm).
MIX_TRAIN, CRM_TRAIN = "../data/audio_train/mix", "../data/audio_train/crm"
MIX_TEST, CRM_TEST = "../data/audio_test/mix", "../data/audio_test/crm"

In [None]:
# Stop training once validation loss has not improved for 3 epochs.
# Monitoring "val_loss" requires passing validation_data to fit().
# restore_best_weights=True rolls the model back to the best epoch's weights
# when training stops early. Without it, the model keeps the weights from
# `patience` epochs *after* the best one, and those are what the later
# inference cell would use.
early_stop = EarlyStopping(monitor="val_loss",
                           min_delta=0,
                           patience=3,
                           verbose=1,
                           mode="auto",
                           baseline=None,
                           restore_best_weights=True)
callbacks = [early_stop]

In [None]:
# Training hyperparameters.
epochs = 100         # upper bound; fits using the early-stopping callback may end sooner
initial_epoch = 0    # epoch index that fit() starts counting from
batch_size = 4
n_speakers = 2       # number of speaker outputs (one cRM / WAV per speaker)

ao_model = build_ao_model(n_speakers)

In [None]:
# Gather the mixture (input) and cRM (target) file lists for each split.
mix_train_files, crm_train_files = utils.get_files(MIX_TRAIN), utils.get_files(CRM_TRAIN)
mix_test_files, crm_test_files = utils.get_files(MIX_TEST), utils.get_files(CRM_TEST)

In [None]:
Xdim = (298, 257, 2)
Ydim = (298, 257, 2, 2)


def prepare_data(mix_files, crm_files, x_dim=Xdim, y_dim=Ydim):
    """Load mixture inputs and their cRM targets into dense in-memory arrays.

    Parameters
    ----------
    mix_files : sequence of str
        Paths to the ``.npy`` mixture arrays, one per example; each should
        load to an array shaped ``x_dim``.
    crm_files : sequence of str
        Pool of ``.npy`` cRM paths. For each mixture, the paths returned by
        ``utils.find_paths_contains(utils.basename(mix_path), crm_files)``
        are loaded and stacked along the last axis of the target.
    x_dim, y_dim : tuple of int
        Per-example input and target shapes. ``y_dim[-1]`` is the number of
        cRM files expected per mixture.

    Returns
    -------
    (X, y) : tuple of np.ndarray
        ``X`` has shape ``(len(mix_files), *x_dim)`` and ``y`` has shape
        ``(len(mix_files), *y_dim)``.

    Raises
    ------
    ValueError
        If a mixture matches a number of cRM files other than ``y_dim[-1]``.
    """
    n_examples = len(mix_files)  # len() works for numpy arrays and plain lists
    n_masks = y_dim[-1]
    # Zero-initialised rather than np.empty so no slot can ever hold garbage.
    X = np.zeros((n_examples, *x_dim))
    y = np.zeros((n_examples, *y_dim))
    for i, mix_path in enumerate(mix_files):
        X[i] = np.load(mix_path)
        mix_filename = utils.basename(mix_path)
        crm_paths = list(utils.find_paths_contains(mix_filename, crm_files))
        # A substring match can be ambiguous or incomplete: fail loudly instead
        # of silently leaving target slots unfilled or overflowing the last axis.
        if len(crm_paths) != n_masks:
            raise ValueError(
                "Expected {} cRM files for {!r}, found {}".format(
                    n_masks, mix_filename, len(crm_paths)))
        for j, crm_path in enumerate(crm_paths):
            y[i, ..., j] = np.load(crm_path)
    return X, y

In [None]:
# Overfitting sanity check: load only the first 10 training mixtures into memory.
x, y = prepare_data(mix_train_files[:10], crm_train_files)
train_ds = tf.data.Dataset.from_tensor_slices((x, y))
# Shuffle *before* repeating so each pass sees every example exactly once.
# Shuffling after repeat() blends examples across epoch boundaries.
# A buffer as large as the subset gives a full uniform shuffle.
train_dataset = train_ds.shuffle(len(x)).repeat().batch(batch_size)

In [None]:
# Fit the small in-memory subset; the model should be able to overfit it.
# No callbacks here: the early-stopping callback monitors val_loss, which this
# run cannot produce (no validation data), so it would only log warnings and
# never stop training.
# One epoch = one pass over the subset, derived from its actual size.
ao_model.fit(train_dataset,
             epochs=epochs,
             verbose=True,
             steps_per_epoch=math.ceil(len(x) / batch_size))

In [None]:
# Batch generators that stream examples from disk for the full training run.
# NOTE(review): the leading files are skipped ([100:] for training, [10:] for
# validation). Presumably this keeps the sanity-check subset and the inference
# example out of these runs -- confirm that this is intentional.
train_generator = AudioGenerator(
    mix_train_files[100:], crm_train_files, n_speakers, batch_size
)
val_generator = AudioGenerator(
    mix_test_files[10:], crm_test_files, n_speakers, batch_size
)

In [None]:
# Main training run:
# - stream batches from the training generator,
# - evaluate on the test-split generator after each epoch,
# - let `callbacks` (early stopping on val_loss) end training once
#   validation stops improving.
ao_model.fit(
    train_generator,
    initial_epoch=initial_epoch,
    epochs=epochs,
    validation_data=val_generator,
    callbacks=callbacks,
    verbose=True,
)

In [None]:
# NOTE(review): this trains the *same* model again. fit() continues from the
# current weights, for another `epochs - initial_epoch` epochs, without
# validation data or early stopping. It looks like a leftover experiment;
# confirm whether it should run under Restart & Run All.
ao_model.fit(
    train_generator,
    initial_epoch=initial_epoch,
    epochs=epochs,
    verbose=True,
)

In [None]:
# Model.fit_generator is deprecated in TF2: Model.fit accepts generators and
# keras.utils.Sequence objects directly, so use fit().
# NOTE(review): like the previous cell, this continues training the same model
# without validation or early stopping -- confirm it is still wanted.
ao_model.fit(train_generator,
             epochs=epochs,
             initial_epoch=initial_epoch)

In [None]:
# Separate one test mixture with the trained model and write one WAV file per
# speaker ("0.wav", "1.wav", ...) to the current working directory.
mix = np.load(mix_test_files[5])
# predict() works on batches: add a leading batch axis, then take the single
# result back out.
cRMs = ao_model.predict(mix[np.newaxis, ...])[0]
for i in range(n_speakers):
    # The last axis of the prediction indexes the speaker.
    cRM = cRMs[:, :, :, i]
    F = utils.icRM(mix, cRM)
    T = utils.istft(F)
    filename = '{}.wav'.format(i)
    # NOTE(review): 16000 Hz is hard-coded; presumably it must match the sample
    # rate the STFT features were computed at -- confirm.
    wavfile.write(filename, 16000, T)