In [None]:
import sys
import os
sys.path.append("../lib")
from models import build_av_model
from generators import AVGenerator
import scipy.io.wavfile as wavfile
import utils
from tensorflow.keras.callbacks import EarlyStopping

import tensorflow as tf
import numpy as np
import math
# Report the TensorFlow version in use, then whether a GPU device is visible.
print('TensorFlow Version: {}'.format(tf.__version__))

# The device name is falsy when no GPU is available.
gpu_device = tf.test.gpu_device_name()
if gpu_device:
    print('Default GPU Device: {}'.format(gpu_device))
else:
    print('No GPU found. Please ensure you have installed TensorFlow correctly')

In [None]:
# Every dataset directory lives under this root (relative to the notebook).
DATA_ROOT = "../data"

# Directories holding the mix / crm files for the train and test splits.
MIX_TRAIN = "{}/audio_train/mix".format(DATA_ROOT)
CRM_TRAIN = "{}/audio_train/crm".format(DATA_ROOT)
MIX_TEST = "{}/audio_test/mix".format(DATA_ROOT)
CRM_TEST = "{}/audio_test/crm".format(DATA_ROOT)

# Directories holding the embedding files for each split.
EMB_TRAIN = "{}/emb/train".format(DATA_ROOT)
EMB_TEST = "{}/emb/test".format(DATA_ROOT)

In [None]:
# Stop training once `val_loss` has gone 3 epochs without improving.
# NOTE(review): with restore_best_weights=False the model keeps the weights of
# the last epoch that ran, not those of the best epoch - confirm this is intended,
# since the model is saved and used for prediction right after training.
early_stop = EarlyStopping(
    monitor="val_loss",
    mode="auto",
    min_delta=0,
    patience=3,
    baseline=None,
    restore_best_weights=False,
    verbose=True,
)

# Callback list handed to `fit`.
callbacks = [early_stop]

In [None]:
# Model / data configuration (shared by the model builder and the generators).
n_speakers = 2
batch_size = 6

# Training schedule passed to `fit`.
initial_epoch = 0
epochs = 100

av_model = build_av_model(n_speakers)

In [None]:
# Collect the mix, crm and embedding file lists for the training split...
mix_train_files, crm_train_files, train_emb_files = map(
    utils.get_files, (MIX_TRAIN, CRM_TRAIN, EMB_TRAIN)
)

# ...and the same three lists for the test split.
mix_test_files, crm_test_files, test_emb_files = map(
    utils.get_files, (MIX_TEST, CRM_TEST, EMB_TEST)
)

In [None]:
# Batch generator fed to `fit` for training (train-split files).
train_generator = AVGenerator(
    mix_train_files,
    crm_train_files,
    train_emb_files,
    n_speakers,
    batch_size,
)

# Batch generator used as validation data (test-split files).
val_generator = AVGenerator(
    mix_test_files,
    crm_test_files,
    test_emb_files,
    n_speakers,
    batch_size,
)

In [None]:
# Train on the training generator and validate on the test-split generator;
# the early-stopping callback may end the run before `epochs` is reached.
av_model.fit(
    train_generator,
    epochs=epochs,
    initial_epoch=initial_epoch,
    validation_data=val_generator,
    callbacks=callbacks,
    verbose=True,
)

In [None]:
# Persist the trained model as <SAVE_MODEL_FOLDER>/<NAME_MODEL>.h5.
NAME_MODEL = "av_model"
SAVE_MODEL_FOLDER = "../data/saved/models"

# presumably creates the folder if it is missing - verify against utils.make_dirs
utils.make_dirs(SAVE_MODEL_FOLDER)

model_path = os.path.join(SAVE_MODEL_FOLDER, NAME_MODEL + ".h5")
av_model.save(model_path)

In [None]:
# Run the trained model on one test mixture and write one WAV file per speaker.
SAVE_AUDIO_FOLDER = "../data/saved/audio"
utils.make_dirs(SAVE_AUDIO_FOLDER)
MIX_TEST_FILE = mix_test_files[3]

mix = np.load(MIX_TEST_FILE)
cleans_name = utils.get_clean_in_mix(MIX_TEST_FILE)

# Single-sample batch of embeddings, one slot per speaker on the last axis.
face_embs = np.zeros((1, 75, 1, 1792, n_speakers))
for i in range(n_speakers):
    # The mixture comes from the test split, so its embeddings are read from
    # EMB_TEST (the validation generator pairs the same two sources).
    emb_path = os.path.join(EMB_TEST, "{}.npy".format(cleans_name[i]))
    # The batch axis has length 1, so the only valid index is 0; indexing it
    # with 1 raised IndexError.
    face_embs[0, :, :, :, i] = np.load(emb_path)

# Add a batch axis to the mixture, predict, then drop the batch axis again.
cRMs = av_model.predict([np.expand_dims(mix, axis=0), face_embs])
cRMs = cRMs[0]
for i in range(n_speakers):
    # Last axis of the prediction selects the speaker.
    cRM = cRMs[:, :, :, i]
    F = utils.icRM(mix, cRM)
    T = utils.istft(F)
    filename = str(i) + '.wav'
    # Written with a 16000 Hz sample rate.
    wavfile.write(os.path.join(SAVE_AUDIO_FOLDER, filename), 16000, T)