In [1]:
import json
from pathlib import Path
from functools import partial

In [4]:
import tensorflow as tf

In [5]:
tf.logging.set_verbosity(tf.logging.DEBUG)

In [6]:
from spleeter.audio.adapter import get_default_audio_adapter
from spleeter.separator import Separator
from spleeter.audio.convertor import to_stereo
from spleeter.model import model_fn
from spleeter.model.provider import ModelProvider
from spleeter.dataset import get_training_dataset, get_validation_dataset

In [7]:
config_path = "config/voice_config.json"

In [8]:
with open(config_path) as f:
    params = json.load(f)

#### Создаем `estimator`

In [9]:
def _create_estimator(params):
    """ Creates estimator.

    :param params: TF params to build estimator from.
    :returns: Built estimator.
    """
    session_config = tf.compat.v1.ConfigProto()
    session_config.gpu_options.per_process_gpu_memory_fraction = 0.85
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=params['model_dir'],
        params=params,
        config=tf.estimator.RunConfig(
            save_checkpoints_steps=params['save_checkpoints_steps'],
            tf_random_seed=params['random_seed'],
            save_summary_steps=params['save_summary_steps'],
            session_config=session_config,
            log_step_count_steps=100,
            keep_checkpoint_max=5))
    return estimator

In [10]:
estimator = _create_estimator(params)
estimator

INFO:tensorflow:Using config: {'_model_dir': 'voice_model', '_tf_random_seed': 3, '_save_summary_steps': 50, '_save_checkpoints_steps': 300, '_save_checkpoints_secs': None, '_session_config': gpu_options {
  per_process_gpu_memory_fraction: 0.85
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x000001BBA59306C8>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}


<tensorflow_estimator.python.estimator.estimator.Estimator at 0x1bba5930448>

#### Загрузчик тренировочных данных

In [11]:
def _create_train_spec(params, audio_adapter, audio_path):
    """ Creates train spec.

    :param params: TF params to build spec from.
    :returns: Built train spec.
    """
    input_fn = partial(get_training_dataset, params, audio_adapter, audio_path)
    train_spec = tf.estimator.TrainSpec(
        input_fn=input_fn,
        max_steps=params['train_max_steps'])
    return train_spec

In [12]:
audio_adapter = get_default_audio_adapter()

In [13]:
audio_path = str(Path('..') / '..' / "data" / "transformed" / "nr")

In [14]:
assert len(list(Path(audio_path).iterdir())) >= 2

In [15]:
train_spec = _create_train_spec(params, audio_adapter, audio_path)

#### Загрузчик валидационных данных

In [16]:
def _create_evaluation_spec(params, audio_adapter, audio_path):
    """ Setup eval spec evaluating ever n seconds

    :param params: TF params to build spec from.
    :returns: Built evaluation spec.
    """
    input_fn = partial(
        get_validation_dataset,
        params,
        audio_adapter,
        audio_path)
    evaluation_spec = tf.estimator.EvalSpec(
        input_fn=input_fn,
        steps=None,
        throttle_secs=params['throttle_secs'])
    return evaluation_spec

In [17]:
evaluation_spec = _create_evaluation_spec(params, audio_adapter, audio_path)

#### Обучение
(вся магия — внутри)

In [18]:
tf.estimator.train_and_evaluate(estimator, train_spec, evaluation_spec)

INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps 300 or save_checkpoints_secs None.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Apply unet for voice_spectrogram
INFO:tensorflow:Apply unet for noise_spectrogram
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into voice_model\model.ckpt.


INFO:spleeter:Loading audio b'..\\..\\data\\transformed\\nr\\train\\mix\\11faafcdf712444880d3964e62b3c84d\\mix.wav' from 0.0 to 15.0




INFO:spleeter:Loading audio b'..\\..\\data\\transformed\\nr\\train\\mix\\0dc3b36938a341749c05a5355dd4c1f2\\mix.wav' from 0.0 to 15.0




INFO:spleeter:Loading audio b'..\\..\\data\\transformed\\nr\\train\\mix\\60fbc287eaea4600bb1d8fb88eb62430\\mix.wav' from 0.0 to 15.0
INFO:spleeter:Loading audio b'..\\..\\data\\transformed\\nr\\train\\mix\\35b4f36726354da3b902fc885fa858a7\\mix.wav' from 0.0 to 15.0
INFO:spleeter:Audio data loaded successfully
INFO:spleeter:Loading audio b'..\\..\\data\\transformed\\nr\\train\\mix\\2228b51a7a68483eaeef283a5cf681b9\\mix.wav' from 0.0 to 15.0
INFO:spleeter:Audio data loaded successfully
INFO:spleeter:Audio data loaded successfully
INFO:spleeter:Audio data loaded successfully
INFO:spleeter:Loading audio b'..\\..\\data\\transformed\\nr\\train\\mix\\570cc46328994bd8b718d08c5ce325c8\\mix.wav' from 0.0 to 15.0
INFO:spleeter:Loading audio b'..\\..\\data\\transformed\\nr\\train\\mix\\14f14d62c9ae402faf8c0e166d054c0c\\mix.wav' from 0.0 to 15.0
INFO:spleeter:Loading audio b'..\\..\\data\\transformed\\nr\\train\\mix\\a143fd4cc7594600b0f939f4f67cb5b6\\mix.wav' from 0.0 to 15.0
INFO:spleeter:Audio da

INFO:tensorflow:loss = 2.5766711, step = 0


INFO:spleeter:Loading audio b'..\\..\\data\\transformed\\nr\\train\\mix\\bc5eeb447aa540a6b129353d1a6c7638\\mix.wav' from 0.0 to 15.0
INFO:spleeter:Audio data loaded successfully
INFO:spleeter:Loading audio b'..\\..\\data\\transformed\\nr\\train\\source\\8a25e9a349fb44cb9b6d855cf0488648\\1.wav' from 0.0 to 15.0
INFO:spleeter:Audio data loaded successfully
INFO:spleeter:Loading audio b'..\\..\\data\\transformed\\nr\\train\\source\\374dccad1c644655b3fa7cf6072dc386\\2.wav' from 0.0 to 15.0
INFO:spleeter:Audio data loaded successfully
INFO:spleeter:Loading audio b'..\\..\\data\\transformed\\nr\\train\\mix\\78a17737a6b9435aa2d56366a00cc22c\\mix.wav' from 0.0 to 15.0
INFO:spleeter:Audio data loaded successfully
INFO:spleeter:Loading audio b'..\\..\\data\\transformed\\nr\\train\\source\\377edcb9c9624e43844065a6bbbc9e16\\2.wav' from 0.0 to 15.0
INFO:spleeter:Audio data loaded successfully
INFO:spleeter:Loading audio b'..\\..\\data\\transformed\\nr\\train\\source\\13db11e4bc4a4934ae9f1d75aaef9d6

INFO:tensorflow:global_step/sec: 0.325425
INFO:tensorflow:loss = 1.6475946, step = 100 (306.989 sec)
INFO:tensorflow:global_step/sec: 1.92474
INFO:tensorflow:loss = 1.6394701, step = 200 (51.955 sec)
INFO:tensorflow:Saving checkpoints for 300 into voice_model\model.ckpt.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Apply unet for voice_spectrogram
INFO:tensorflow:Apply unet for noise_spectrogram
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2019-12-05T04:30:11Z
INFO:tensorflow:Graph was finalized.
Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from voice_model\model.ckpt-300
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.


INFO:spleeter:Loading audio b'..\\..\\data\\transformed\\nr\\test\\mix\\ccc0f5f83eab4172838dbd9820189421\\mix.wav' from 1.5 to 13.5
INFO:spleeter:Loading audio b'..\\..\\data\\transformed\\nr\\test\\mix\\ccf2a3d9eea04154ba14e05c4bf77ffb\\mix.wav' from 1.5 to 13.5
INFO:spleeter:Loading audio b'..\\..\\data\\transformed\\nr\\test\\mix\\cd1d24c1b72040f2b6548c281a722ac5\\mix.wav' from 1.5 to 13.5
INFO:spleeter:Loading audio b'..\\..\\data\\transformed\\nr\\test\\mix\\cd257893cc344a988564cc88c173cbf8\\mix.wav' from 1.5 to 13.5
INFO:spleeter:Audio data loaded successfully
INFO:spleeter:Loading audio b'..\\..\\data\\transformed\\nr\\test\\mix\\cdad48658f2b4207afb0caef69307106\\mix.wav' from 1.5 to 13.5
INFO:spleeter:Audio data loaded successfully
INFO:spleeter:Audio data loaded successfully
INFO:spleeter:Audio data loaded successfully
INFO:spleeter:Loading audio b'..\\..\\data\\transformed\\nr\\test\\mix\\cdd62e0705f24d678c0c7fdb53a07009\\mix.wav' from 1.5 to 13.5
INFO:spleeter:Loading audio 

INFO:tensorflow:Finished evaluation at 2019-12-05-04:31:18
INFO:tensorflow:Saving dict for global step 300: absolute_difference = 1.5165284, global_step = 300, loss = 1.5165284, noise_spectrogram = 0.94527346, voice_spectrogram = 0.571255
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 300: voice_model\model.ckpt-300
INFO:tensorflow:global_step/sec: 0.77153
INFO:tensorflow:loss = 1.6227119, step = 300 (129.612 sec)
INFO:tensorflow:global_step/sec: 1.95042
INFO:tensorflow:loss = 1.4468017, step = 400 (51.299 sec)
INFO:tensorflow:global_step/sec: 1.96854
INFO:tensorflow:loss = 1.6141794, step = 500 (50.772 sec)
INFO:tensorflow:Saving checkpoints for 600 into voice_model\model.ckpt.
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs).
INFO:tensorflow:global_step/sec: 1.81373
INFO:tensorflow:loss = 1.5832071, step = 600 (55.135 sec)
INFO:tensorflow:global_step/sec: 2.02189
INFO:tensorflow:loss = 1.5062239, step = 700 (49.459 sec)
INFO:tensorflo

KeyboardInterrupt: 

#### Сохраним полученные веса

In [20]:
ModelProvider.writeProbe(params['model_dir'])

Вышло на плато, поэтому остановил обучение (суммарно обучалось 8 часов 50 минут, за это время 61.9к шагов)  

![abs_loss](https://i.imgur.com/ycLcsGS.png)

![spectr_loss](https://i.imgur.com/Zt2hwho.png)